anthropic.rs

   1use crate::AllLanguageModelSettings;
   2use crate::ui::InstructionListItem;
   3use anthropic::{
   4    AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent,
   5    ToolResultPart, Usage,
   6};
   7use anyhow::{Context as _, Result, anyhow};
   8use collections::{BTreeMap, HashMap};
   9use credentials_provider::CredentialsProvider;
  10use editor::{Editor, EditorElement, EditorStyle};
  11use futures::Stream;
  12use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  13use gpui::{
  14    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
  15};
  16use http_client::HttpClient;
  17use language_model::{
  18    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  19    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
  20    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  21    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
  22    LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
  23};
  24use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
  25use schemars::JsonSchema;
  26use serde::{Deserialize, Serialize};
  27use settings::{Settings, SettingsStore};
  28use std::pin::Pin;
  29use std::str::FromStr;
  30use std::sync::Arc;
  31use strum::IntoEnumIterator;
  32use theme::ThemeSettings;
  33use ui::{Icon, IconName, List, Tooltip, prelude::*};
  34use util::ResultExt;
  35
  36const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
  37const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
  38
  39#[derive(Default, Clone, Debug, PartialEq)]
  40pub struct AnthropicSettings {
  41    pub api_url: String,
  42    /// Extend Zed's list of Anthropic models.
  43    pub available_models: Vec<AvailableModel>,
  44}
  45
  46#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
  47pub struct AvailableModel {
  48    /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-latest, claude-3-opus-20240229, etc
  49    pub name: String,
  50    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
  51    pub display_name: Option<String>,
  52    /// The model's context window size.
  53    pub max_tokens: u64,
  54    /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
  55    pub tool_override: Option<String>,
  56    /// Configuration of Anthropic's caching API.
  57    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
  58    pub max_output_tokens: Option<u64>,
  59    pub default_temperature: Option<f32>,
  60    #[serde(default)]
  61    pub extra_beta_headers: Vec<String>,
  62    /// The model's mode (e.g. thinking)
  63    pub mode: Option<ModelMode>,
  64}
  65
  66#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
  67#[serde(tag = "type", rename_all = "lowercase")]
  68pub enum ModelMode {
  69    #[default]
  70    Default,
  71    Thinking {
  72        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
  73        budget_tokens: Option<u32>,
  74    },
  75}
  76
  77impl From<ModelMode> for AnthropicModelMode {
  78    fn from(value: ModelMode) -> Self {
  79        match value {
  80            ModelMode::Default => AnthropicModelMode::Default,
  81            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  82        }
  83    }
  84}
  85
  86impl From<AnthropicModelMode> for ModelMode {
  87    fn from(value: AnthropicModelMode) -> Self {
  88        match value {
  89            AnthropicModelMode::Default => ModelMode::Default,
  90            AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
  91        }
  92    }
  93}
  94
  95pub struct AnthropicLanguageModelProvider {
  96    http_client: Arc<dyn HttpClient>,
  97    state: gpui::Entity<State>,
  98}
  99
 /// Environment variable consulted for an API key before falling back to the credentials provider.
const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
 101
 102pub struct State {
 103    api_key: Option<String>,
 104    api_key_from_env: bool,
 105    _subscription: Subscription,
 106}
 107
 108impl State {
 109    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 110        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 111        let api_url = AllLanguageModelSettings::get_global(cx)
 112            .anthropic
 113            .api_url
 114            .clone();
 115        cx.spawn(async move |this, cx| {
 116            credentials_provider
 117                .delete_credentials(&api_url, cx)
 118                .await
 119                .ok();
 120            this.update(cx, |this, cx| {
 121                this.api_key = None;
 122                this.api_key_from_env = false;
 123                cx.notify();
 124            })
 125        })
 126    }
 127
 128    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
 129        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 130        let api_url = AllLanguageModelSettings::get_global(cx)
 131            .anthropic
 132            .api_url
 133            .clone();
 134        cx.spawn(async move |this, cx| {
 135            credentials_provider
 136                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
 137                .await
 138                .ok();
 139
 140            this.update(cx, |this, cx| {
 141                this.api_key = Some(api_key);
 142                cx.notify();
 143            })
 144        })
 145    }
 146
 147    fn is_authenticated(&self) -> bool {
 148        self.api_key.is_some()
 149    }
 150
 151    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 152        if self.is_authenticated() {
 153            return Task::ready(Ok(()));
 154        }
 155
 156        let key = AnthropicLanguageModelProvider::api_key(cx);
 157
 158        cx.spawn(async move |this, cx| {
 159            let key = key.await?;
 160
 161            this.update(cx, |this, cx| {
 162                this.api_key = Some(key.key);
 163                this.api_key_from_env = key.from_env;
 164                cx.notify();
 165            })?;
 166
 167            Ok(())
 168        })
 169    }
 170}
 171
 /// An Anthropic API key together with where it was loaded from.
pub struct ApiKey {
    /// The raw API key.
    pub key: String,
    /// `true` if the key came from the `ANTHROPIC_API_KEY` environment variable,
    /// `false` if it was read from the credentials provider.
    pub from_env: bool,
}
 176
 177impl AnthropicLanguageModelProvider {
 178    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 179        let state = cx.new(|cx| State {
 180            api_key: None,
 181            api_key_from_env: false,
 182            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 183                cx.notify();
 184            }),
 185        });
 186
 187        Self { http_client, state }
 188    }
 189
 190    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
 191        Arc::new(AnthropicModel {
 192            id: LanguageModelId::from(model.id().to_string()),
 193            model,
 194            state: self.state.clone(),
 195            http_client: self.http_client.clone(),
 196            request_limiter: RateLimiter::new(4),
 197        })
 198    }
 199
 200    pub fn api_key(cx: &mut App) -> Task<Result<ApiKey, AuthenticateError>> {
 201        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 202        let api_url = AllLanguageModelSettings::get_global(cx)
 203            .anthropic
 204            .api_url
 205            .clone();
 206
 207        if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
 208            Task::ready(Ok(ApiKey {
 209                key,
 210                from_env: true,
 211            }))
 212        } else {
 213            cx.spawn(async move |cx| {
 214                let (_, api_key) = credentials_provider
 215                    .read_credentials(&api_url, cx)
 216                    .await?
 217                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 218
 219                Ok(ApiKey {
 220                    key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
 221                    from_env: false,
 222                })
 223            })
 224        }
 225    }
 226}
 227
 228impl LanguageModelProviderState for AnthropicLanguageModelProvider {
 229    type ObservableEntity = State;
 230
 231    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
 232        Some(self.state.clone())
 233    }
 234}
 235
 236impl LanguageModelProvider for AnthropicLanguageModelProvider {
 237    fn id(&self) -> LanguageModelProviderId {
 238        PROVIDER_ID
 239    }
 240
 241    fn name(&self) -> LanguageModelProviderName {
 242        PROVIDER_NAME
 243    }
 244
 245    fn icon(&self) -> IconName {
 246        IconName::AiAnthropic
 247    }
 248
 249    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 250        Some(self.create_language_model(anthropic::Model::default()))
 251    }
 252
 253    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 254        Some(self.create_language_model(anthropic::Model::default_fast()))
 255    }
 256
 257    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 258        [
 259            anthropic::Model::ClaudeSonnet4,
 260            anthropic::Model::ClaudeSonnet4Thinking,
 261        ]
 262        .into_iter()
 263        .map(|model| self.create_language_model(model))
 264        .collect()
 265    }
 266
 267    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 268        let mut models = BTreeMap::default();
 269
 270        // Add base models from anthropic::Model::iter()
 271        for model in anthropic::Model::iter() {
 272            if !matches!(model, anthropic::Model::Custom { .. }) {
 273                models.insert(model.id().to_string(), model);
 274            }
 275        }
 276
 277        // Override with available models from settings
 278        for model in AllLanguageModelSettings::get_global(cx)
 279            .anthropic
 280            .available_models
 281            .iter()
 282        {
 283            models.insert(
 284                model.name.clone(),
 285                anthropic::Model::Custom {
 286                    name: model.name.clone(),
 287                    display_name: model.display_name.clone(),
 288                    max_tokens: model.max_tokens,
 289                    tool_override: model.tool_override.clone(),
 290                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
 291                        anthropic::AnthropicModelCacheConfiguration {
 292                            max_cache_anchors: config.max_cache_anchors,
 293                            should_speculate: config.should_speculate,
 294                            min_total_token: config.min_total_token,
 295                        }
 296                    }),
 297                    max_output_tokens: model.max_output_tokens,
 298                    default_temperature: model.default_temperature,
 299                    extra_beta_headers: model.extra_beta_headers.clone(),
 300                    mode: model.mode.clone().unwrap_or_default().into(),
 301                },
 302            );
 303        }
 304
 305        models
 306            .into_values()
 307            .map(|model| self.create_language_model(model))
 308            .collect()
 309    }
 310
 311    fn is_authenticated(&self, cx: &App) -> bool {
 312        self.state.read(cx).is_authenticated()
 313    }
 314
 315    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 316        self.state.update(cx, |state, cx| state.authenticate(cx))
 317    }
 318
 319    fn configuration_view(
 320        &self,
 321        target_agent: ConfigurationViewTargetAgent,
 322        window: &mut Window,
 323        cx: &mut App,
 324    ) -> AnyView {
 325        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
 326            .into()
 327    }
 328
 329    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 330        self.state.update(cx, |state, cx| state.reset_api_key(cx))
 331    }
 332}
 333
 334pub struct AnthropicModel {
 335    id: LanguageModelId,
 336    model: anthropic::Model,
 337    state: gpui::Entity<State>,
 338    http_client: Arc<dyn HttpClient>,
 339    request_limiter: RateLimiter,
 340}
 341
 342pub fn count_anthropic_tokens(
 343    request: LanguageModelRequest,
 344    cx: &App,
 345) -> BoxFuture<'static, Result<u64>> {
 346    cx.background_spawn(async move {
 347        let messages = request.messages;
 348        let mut tokens_from_images = 0;
 349        let mut string_messages = Vec::with_capacity(messages.len());
 350
 351        for message in messages {
 352            use language_model::MessageContent;
 353
 354            let mut string_contents = String::new();
 355
 356            for content in message.content {
 357                match content {
 358                    MessageContent::Text(text) => {
 359                        string_contents.push_str(&text);
 360                    }
 361                    MessageContent::Thinking { .. } => {
 362                        // Thinking blocks are not included in the input token count.
 363                    }
 364                    MessageContent::RedactedThinking(_) => {
 365                        // Thinking blocks are not included in the input token count.
 366                    }
 367                    MessageContent::Image(image) => {
 368                        tokens_from_images += image.estimate_tokens();
 369                    }
 370                    MessageContent::ToolUse(_tool_use) => {
 371                        // TODO: Estimate token usage from tool uses.
 372                    }
 373                    MessageContent::ToolResult(tool_result) => match &tool_result.content {
 374                        LanguageModelToolResultContent::Text(text) => {
 375                            string_contents.push_str(text);
 376                        }
 377                        LanguageModelToolResultContent::Image(image) => {
 378                            tokens_from_images += image.estimate_tokens();
 379                        }
 380                    },
 381                }
 382            }
 383
 384            if !string_contents.is_empty() {
 385                string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
 386                    role: match message.role {
 387                        Role::User => "user".into(),
 388                        Role::Assistant => "assistant".into(),
 389                        Role::System => "system".into(),
 390                    },
 391                    content: Some(string_contents),
 392                    name: None,
 393                    function_call: None,
 394                });
 395            }
 396        }
 397
 398        // Tiktoken doesn't yet support these models, so we manually use the
 399        // same tokenizer as GPT-4.
 400        tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
 401            .map(|tokens| (tokens + tokens_from_images) as u64)
 402    })
 403    .boxed()
 404}
 405
 406impl AnthropicModel {
 407    fn stream_completion(
 408        &self,
 409        request: anthropic::Request,
 410        cx: &AsyncApp,
 411    ) -> BoxFuture<
 412        'static,
 413        Result<
 414            BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
 415            LanguageModelCompletionError,
 416        >,
 417    > {
 418        let http_client = self.http_client.clone();
 419
 420        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
 421            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
 422            (state.api_key.clone(), settings.api_url.clone())
 423        }) else {
 424            return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
 425        };
 426
 427        let beta_headers = self.model.beta_headers();
 428
 429        async move {
 430            let Some(api_key) = api_key else {
 431                return Err(LanguageModelCompletionError::NoApiKey {
 432                    provider: PROVIDER_NAME,
 433                });
 434            };
 435            let request = anthropic::stream_completion(
 436                http_client.as_ref(),
 437                &api_url,
 438                &api_key,
 439                request,
 440                beta_headers,
 441            );
 442            request.await.map_err(Into::into)
 443        }
 444        .boxed()
 445    }
 446}
 447
 448impl LanguageModel for AnthropicModel {
 449    fn id(&self) -> LanguageModelId {
 450        self.id.clone()
 451    }
 452
 453    fn name(&self) -> LanguageModelName {
 454        LanguageModelName::from(self.model.display_name().to_string())
 455    }
 456
 457    fn provider_id(&self) -> LanguageModelProviderId {
 458        PROVIDER_ID
 459    }
 460
 461    fn provider_name(&self) -> LanguageModelProviderName {
 462        PROVIDER_NAME
 463    }
 464
 465    fn supports_tools(&self) -> bool {
 466        true
 467    }
 468
 469    fn supports_images(&self) -> bool {
 470        true
 471    }
 472
 473    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 474        match choice {
 475            LanguageModelToolChoice::Auto
 476            | LanguageModelToolChoice::Any
 477            | LanguageModelToolChoice::None => true,
 478        }
 479    }
 480
 481    fn telemetry_id(&self) -> String {
 482        format!("anthropic/{}", self.model.id())
 483    }
 484
 485    fn api_key(&self, cx: &App) -> Option<String> {
 486        self.state.read(cx).api_key.clone()
 487    }
 488
 489    fn max_token_count(&self) -> u64 {
 490        self.model.max_token_count()
 491    }
 492
 493    fn max_output_tokens(&self) -> Option<u64> {
 494        Some(self.model.max_output_tokens())
 495    }
 496
 497    fn count_tokens(
 498        &self,
 499        request: LanguageModelRequest,
 500        cx: &App,
 501    ) -> BoxFuture<'static, Result<u64>> {
 502        count_anthropic_tokens(request, cx)
 503    }
 504
 505    fn stream_completion(
 506        &self,
 507        request: LanguageModelRequest,
 508        cx: &AsyncApp,
 509    ) -> BoxFuture<
 510        'static,
 511        Result<
 512            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 513            LanguageModelCompletionError,
 514        >,
 515    > {
 516        let request = into_anthropic(
 517            request,
 518            self.model.request_id().into(),
 519            self.model.default_temperature(),
 520            self.model.max_output_tokens(),
 521            self.model.mode(),
 522        );
 523        let request = self.stream_completion(request, cx);
 524        let future = self.request_limiter.stream(async move {
 525            let response = request.await?;
 526            Ok(AnthropicEventMapper::new().map_stream(response))
 527        });
 528        async move { Ok(future.await?.boxed()) }.boxed()
 529    }
 530
 531    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 532        self.model
 533            .cache_configuration()
 534            .map(|config| LanguageModelCacheConfiguration {
 535                max_cache_anchors: config.max_cache_anchors,
 536                should_speculate: config.should_speculate,
 537                min_total_token: config.min_total_token,
 538            })
 539    }
 540}
 541
 542pub fn into_anthropic(
 543    request: LanguageModelRequest,
 544    model: String,
 545    default_temperature: f32,
 546    max_output_tokens: u64,
 547    mode: AnthropicModelMode,
 548) -> anthropic::Request {
 549    let mut new_messages: Vec<anthropic::Message> = Vec::new();
 550    let mut system_message = String::new();
 551
 552    for message in request.messages {
 553        if message.contents_empty() {
 554            continue;
 555        }
 556
 557        match message.role {
 558            Role::User | Role::Assistant => {
 559                let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
 560                    .content
 561                    .into_iter()
 562                    .filter_map(|content| match content {
 563                        MessageContent::Text(text) => {
 564                            let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
 565                                text.trim_end().to_string()
 566                            } else {
 567                                text
 568                            };
 569                            if !text.is_empty() {
 570                                Some(anthropic::RequestContent::Text {
 571                                    text,
 572                                    cache_control: None,
 573                                })
 574                            } else {
 575                                None
 576                            }
 577                        }
 578                        MessageContent::Thinking {
 579                            text: thinking,
 580                            signature,
 581                        } => {
 582                            if !thinking.is_empty() {
 583                                Some(anthropic::RequestContent::Thinking {
 584                                    thinking,
 585                                    signature: signature.unwrap_or_default(),
 586                                    cache_control: None,
 587                                })
 588                            } else {
 589                                None
 590                            }
 591                        }
 592                        MessageContent::RedactedThinking(data) => {
 593                            if !data.is_empty() {
 594                                Some(anthropic::RequestContent::RedactedThinking { data })
 595                            } else {
 596                                None
 597                            }
 598                        }
 599                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
 600                            source: anthropic::ImageSource {
 601                                source_type: "base64".to_string(),
 602                                media_type: "image/png".to_string(),
 603                                data: image.source.to_string(),
 604                            },
 605                            cache_control: None,
 606                        }),
 607                        MessageContent::ToolUse(tool_use) => {
 608                            Some(anthropic::RequestContent::ToolUse {
 609                                id: tool_use.id.to_string(),
 610                                name: tool_use.name.to_string(),
 611                                input: tool_use.input,
 612                                cache_control: None,
 613                            })
 614                        }
 615                        MessageContent::ToolResult(tool_result) => {
 616                            Some(anthropic::RequestContent::ToolResult {
 617                                tool_use_id: tool_result.tool_use_id.to_string(),
 618                                is_error: tool_result.is_error,
 619                                content: match tool_result.content {
 620                                    LanguageModelToolResultContent::Text(text) => {
 621                                        ToolResultContent::Plain(text.to_string())
 622                                    }
 623                                    LanguageModelToolResultContent::Image(image) => {
 624                                        ToolResultContent::Multipart(vec![ToolResultPart::Image {
 625                                            source: anthropic::ImageSource {
 626                                                source_type: "base64".to_string(),
 627                                                media_type: "image/png".to_string(),
 628                                                data: image.source.to_string(),
 629                                            },
 630                                        }])
 631                                    }
 632                                },
 633                                cache_control: None,
 634                            })
 635                        }
 636                    })
 637                    .collect();
 638                let anthropic_role = match message.role {
 639                    Role::User => anthropic::Role::User,
 640                    Role::Assistant => anthropic::Role::Assistant,
 641                    Role::System => unreachable!("System role should never occur here"),
 642                };
 643                if let Some(last_message) = new_messages.last_mut()
 644                    && last_message.role == anthropic_role
 645                {
 646                    last_message.content.extend(anthropic_message_content);
 647                    continue;
 648                }
 649
 650                // Mark the last segment of the message as cached
 651                if message.cache {
 652                    let cache_control_value = Some(anthropic::CacheControl {
 653                        cache_type: anthropic::CacheControlType::Ephemeral,
 654                    });
 655                    for message_content in anthropic_message_content.iter_mut().rev() {
 656                        match message_content {
 657                            anthropic::RequestContent::RedactedThinking { .. } => {
 658                                // Caching is not possible, fallback to next message
 659                            }
 660                            anthropic::RequestContent::Text { cache_control, .. }
 661                            | anthropic::RequestContent::Thinking { cache_control, .. }
 662                            | anthropic::RequestContent::Image { cache_control, .. }
 663                            | anthropic::RequestContent::ToolUse { cache_control, .. }
 664                            | anthropic::RequestContent::ToolResult { cache_control, .. } => {
 665                                *cache_control = cache_control_value;
 666                                break;
 667                            }
 668                        }
 669                    }
 670                }
 671
 672                new_messages.push(anthropic::Message {
 673                    role: anthropic_role,
 674                    content: anthropic_message_content,
 675                });
 676            }
 677            Role::System => {
 678                if !system_message.is_empty() {
 679                    system_message.push_str("\n\n");
 680                }
 681                system_message.push_str(&message.string_contents());
 682            }
 683        }
 684    }
 685
 686    anthropic::Request {
 687        model,
 688        messages: new_messages,
 689        max_tokens: max_output_tokens,
 690        system: if system_message.is_empty() {
 691            None
 692        } else {
 693            Some(anthropic::StringOrContents::String(system_message))
 694        },
 695        thinking: if request.thinking_allowed
 696            && let AnthropicModelMode::Thinking { budget_tokens } = mode
 697        {
 698            Some(anthropic::Thinking::Enabled { budget_tokens })
 699        } else {
 700            None
 701        },
 702        tools: request
 703            .tools
 704            .into_iter()
 705            .map(|tool| anthropic::Tool {
 706                name: tool.name,
 707                description: tool.description,
 708                input_schema: tool.input_schema,
 709            })
 710            .collect(),
 711        tool_choice: request.tool_choice.map(|choice| match choice {
 712            LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
 713            LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
 714            LanguageModelToolChoice::None => anthropic::ToolChoice::None,
 715        }),
 716        metadata: None,
 717        stop_sequences: Vec::new(),
 718        temperature: request.temperature.or(Some(default_temperature)),
 719        top_k: None,
 720        top_p: None,
 721    }
 722}
 723
 724pub struct AnthropicEventMapper {
 725    tool_uses_by_index: HashMap<usize, RawToolUse>,
 726    usage: Usage,
 727    stop_reason: StopReason,
 728}
 729
 730impl AnthropicEventMapper {
 731    pub fn new() -> Self {
 732        Self {
 733            tool_uses_by_index: HashMap::default(),
 734            usage: Usage::default(),
 735            stop_reason: StopReason::EndTurn,
 736        }
 737    }
 738
 739    pub fn map_stream(
 740        mut self,
 741        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
 742    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 743    {
 744        events.flat_map(move |event| {
 745            futures::stream::iter(match event {
 746                Ok(event) => self.map_event(event),
 747                Err(error) => vec![Err(error.into())],
 748            })
 749        })
 750    }
 751
 752    pub fn map_event(
 753        &mut self,
 754        event: Event,
 755    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 756        match event {
 757            Event::ContentBlockStart {
 758                index,
 759                content_block,
 760            } => match content_block {
 761                ResponseContent::Text { text } => {
 762                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
 763                }
 764                ResponseContent::Thinking { thinking } => {
 765                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 766                        text: thinking,
 767                        signature: None,
 768                    })]
 769                }
 770                ResponseContent::RedactedThinking { data } => {
 771                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
 772                }
 773                ResponseContent::ToolUse { id, name, .. } => {
 774                    self.tool_uses_by_index.insert(
 775                        index,
 776                        RawToolUse {
 777                            id,
 778                            name,
 779                            input_json: String::new(),
 780                        },
 781                    );
 782                    Vec::new()
 783                }
 784            },
 785            Event::ContentBlockDelta { index, delta } => match delta {
 786                ContentDelta::TextDelta { text } => {
 787                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
 788                }
 789                ContentDelta::ThinkingDelta { thinking } => {
 790                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 791                        text: thinking,
 792                        signature: None,
 793                    })]
 794                }
 795                ContentDelta::SignatureDelta { signature } => {
 796                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 797                        text: "".to_string(),
 798                        signature: Some(signature),
 799                    })]
 800                }
 801                ContentDelta::InputJsonDelta { partial_json } => {
 802                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
 803                        tool_use.input_json.push_str(&partial_json);
 804
 805                        // Try to convert invalid (incomplete) JSON into
 806                        // valid JSON that serde can accept, e.g. by closing
 807                        // unclosed delimiters. This way, we can update the
 808                        // UI with whatever has been streamed back so far.
 809                        if let Ok(input) = serde_json::Value::from_str(
 810                            &partial_json_fixer::fix_json(&tool_use.input_json),
 811                        ) {
 812                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
 813                                LanguageModelToolUse {
 814                                    id: tool_use.id.clone().into(),
 815                                    name: tool_use.name.clone().into(),
 816                                    is_input_complete: false,
 817                                    raw_input: tool_use.input_json.clone(),
 818                                    input,
 819                                },
 820                            ))];
 821                        }
 822                    }
 823                    vec![]
 824                }
 825            },
 826            Event::ContentBlockStop { index } => {
 827                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
 828                    let input_json = tool_use.input_json.trim();
 829                    let input_value = if input_json.is_empty() {
 830                        Ok(serde_json::Value::Object(serde_json::Map::default()))
 831                    } else {
 832                        serde_json::Value::from_str(input_json)
 833                    };
 834                    let event_result = match input_value {
 835                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
 836                            LanguageModelToolUse {
 837                                id: tool_use.id.into(),
 838                                name: tool_use.name.into(),
 839                                is_input_complete: true,
 840                                input,
 841                                raw_input: tool_use.input_json.clone(),
 842                            },
 843                        )),
 844                        Err(json_parse_err) => {
 845                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 846                                id: tool_use.id.into(),
 847                                tool_name: tool_use.name.into(),
 848                                raw_input: input_json.into(),
 849                                json_parse_error: json_parse_err.to_string(),
 850                            })
 851                        }
 852                    };
 853
 854                    vec![event_result]
 855                } else {
 856                    Vec::new()
 857                }
 858            }
 859            Event::MessageStart { message } => {
 860                update_usage(&mut self.usage, &message.usage);
 861                vec![
 862                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
 863                        &self.usage,
 864                    ))),
 865                    Ok(LanguageModelCompletionEvent::StartMessage {
 866                        message_id: message.id,
 867                    }),
 868                ]
 869            }
 870            Event::MessageDelta { delta, usage } => {
 871                update_usage(&mut self.usage, &usage);
 872                if let Some(stop_reason) = delta.stop_reason.as_deref() {
 873                    self.stop_reason = match stop_reason {
 874                        "end_turn" => StopReason::EndTurn,
 875                        "max_tokens" => StopReason::MaxTokens,
 876                        "tool_use" => StopReason::ToolUse,
 877                        "refusal" => StopReason::Refusal,
 878                        _ => {
 879                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
 880                            StopReason::EndTurn
 881                        }
 882                    };
 883                }
 884                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
 885                    convert_usage(&self.usage),
 886                ))]
 887            }
 888            Event::MessageStop => {
 889                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
 890            }
 891            Event::Error { error } => {
 892                vec![Err(error.into())]
 893            }
 894            _ => Vec::new(),
 895        }
 896    }
 897}
 898
 899struct RawToolUse {
 900    id: String,
 901    name: String,
 902    input_json: String,
 903}
 904
 905/// Updates usage data by preferring counts from `new`.
 906fn update_usage(usage: &mut Usage, new: &Usage) {
 907    if let Some(input_tokens) = new.input_tokens {
 908        usage.input_tokens = Some(input_tokens);
 909    }
 910    if let Some(output_tokens) = new.output_tokens {
 911        usage.output_tokens = Some(output_tokens);
 912    }
 913    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
 914        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
 915    }
 916    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
 917        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
 918    }
 919}
 920
 921fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
 922    language_model::TokenUsage {
 923        input_tokens: usage.input_tokens.unwrap_or(0),
 924        output_tokens: usage.output_tokens.unwrap_or(0),
 925        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
 926        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
 927    }
 928}
 929
 930struct ConfigurationView {
 931    api_key_editor: Entity<Editor>,
 932    state: gpui::Entity<State>,
 933    load_credentials_task: Option<Task<()>>,
 934    target_agent: ConfigurationViewTargetAgent,
 935}
 936
 937impl ConfigurationView {
 938    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
 939
 940    fn new(
 941        state: gpui::Entity<State>,
 942        target_agent: ConfigurationViewTargetAgent,
 943        window: &mut Window,
 944        cx: &mut Context<Self>,
 945    ) -> Self {
 946        cx.observe(&state, |_, _, cx| {
 947            cx.notify();
 948        })
 949        .detach();
 950
 951        let load_credentials_task = Some(cx.spawn({
 952            let state = state.clone();
 953            async move |this, cx| {
 954                if let Some(task) = state
 955                    .update(cx, |state, cx| state.authenticate(cx))
 956                    .log_err()
 957                {
 958                    // We don't log an error, because "not signed in" is also an error.
 959                    let _ = task.await;
 960                }
 961                this.update(cx, |this, cx| {
 962                    this.load_credentials_task = None;
 963                    cx.notify();
 964                })
 965                .log_err();
 966            }
 967        }));
 968
 969        Self {
 970            api_key_editor: cx.new(|cx| {
 971                let mut editor = Editor::single_line(window, cx);
 972                editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, window, cx);
 973                editor
 974            }),
 975            state,
 976            load_credentials_task,
 977            target_agent,
 978        }
 979    }
 980
 981    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 982        let api_key = self.api_key_editor.read(cx).text(cx);
 983        if api_key.is_empty() {
 984            return;
 985        }
 986
 987        let state = self.state.clone();
 988        cx.spawn_in(window, async move |_, cx| {
 989            state
 990                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
 991                .await
 992        })
 993        .detach_and_log_err(cx);
 994
 995        cx.notify();
 996    }
 997
 998    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 999        self.api_key_editor
1000            .update(cx, |editor, cx| editor.set_text("", window, cx));
1001
1002        let state = self.state.clone();
1003        cx.spawn_in(window, async move |_, cx| {
1004            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
1005        })
1006        .detach_and_log_err(cx);
1007
1008        cx.notify();
1009    }
1010
1011    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
1012        let settings = ThemeSettings::get_global(cx);
1013        let text_style = TextStyle {
1014            color: cx.theme().colors().text,
1015            font_family: settings.ui_font.family.clone(),
1016            font_features: settings.ui_font.features.clone(),
1017            font_fallbacks: settings.ui_font.fallbacks.clone(),
1018            font_size: rems(0.875).into(),
1019            font_weight: settings.ui_font.weight,
1020            font_style: FontStyle::Normal,
1021            line_height: relative(1.3),
1022            white_space: WhiteSpace::Normal,
1023            ..Default::default()
1024        };
1025        EditorElement::new(
1026            &self.api_key_editor,
1027            EditorStyle {
1028                background: cx.theme().colors().editor_background,
1029                local_player: cx.theme().players().local(),
1030                text: text_style,
1031                ..Default::default()
1032            },
1033        )
1034    }
1035
1036    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
1037        !self.state.read(cx).is_authenticated()
1038    }
1039}
1040
1041impl Render for ConfigurationView {
1042    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1043        let env_var_set = self.state.read(cx).api_key_from_env;
1044
1045        if self.load_credentials_task.is_some() {
1046            div().child(Label::new("Loading credentials...")).into_any()
1047        } else if self.should_render_editor(cx) {
1048            v_flex()
1049                .size_full()
1050                .on_action(cx.listener(Self::save_api_key))
1051                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
1052                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(),
1053                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
1054                })))
1055                .child(
1056                    List::new()
1057                        .child(
1058                            InstructionListItem::new(
1059                                "Create one by visiting",
1060                                Some("Anthropic's settings"),
1061                                Some("https://console.anthropic.com/settings/keys")
1062                            )
1063                        )
1064                        .child(
1065                            InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent")
1066                        )
1067                )
1068                .child(
1069                    h_flex()
1070                        .w_full()
1071                        .my_2()
1072                        .px_2()
1073                        .py_1()
1074                        .bg(cx.theme().colors().editor_background)
1075                        .border_1()
1076                        .border_color(cx.theme().colors().border)
1077                        .rounded_sm()
1078                        .child(self.render_api_key_editor(cx)),
1079                )
1080                .child(
1081                    Label::new(
1082                        format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
1083                    )
1084                    .size(LabelSize::Small)
1085                    .color(Color::Muted),
1086                )
1087                .into_any()
1088        } else {
1089            h_flex()
1090                .mt_1()
1091                .p_1()
1092                .justify_between()
1093                .rounded_md()
1094                .border_1()
1095                .border_color(cx.theme().colors().border)
1096                .bg(cx.theme().colors().background)
1097                .child(
1098                    h_flex()
1099                        .gap_1()
1100                        .child(Icon::new(IconName::Check).color(Color::Success))
1101                        .child(Label::new(if env_var_set {
1102                            format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
1103                        } else {
1104                            "API key configured.".to_string()
1105                        })),
1106                )
1107                .child(
1108                    Button::new("reset-key", "Reset Key")
1109                        .label_size(LabelSize::Small)
1110                        .icon(Some(IconName::Trash))
1111                        .icon_size(IconSize::Small)
1112                        .icon_position(IconPosition::Start)
1113                        .disabled(env_var_set)
1114                        .when(env_var_set, |this| {
1115                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable.")))
1116                        })
1117                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
1118                )
1119                .into_any()
1120        }
1121    }
1122}
1123
#[cfg(test)]
mod tests {
    use super::*;
    use anthropic::AnthropicModelMode;
    use language_model::{LanguageModelRequestMessage, MessageContent};

    /// A cached message must carry its `cache_control` breakpoint on the final
    /// content segment only; every earlier segment (text or image) stays uncached.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                ],
                cache: true,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);
        let last_index = message.content.len() - 1;

        assert!(matches!(
            message.content[0],
            anthropic::RequestContent::Text {
                cache_control: None,
                ..
            }
        ));

        // Every image between the leading text and the final segment must be
        // uncached.
        for (offset, segment) in message.content[1..last_index].iter().enumerate() {
            assert!(
                matches!(
                    segment,
                    anthropic::RequestContent::Image {
                        cache_control: None,
                        ..
                    }
                ),
                "image segment {} must not carry cache_control",
                offset + 1
            );
        }

        assert!(matches!(
            message.content[last_index],
            anthropic::RequestContent::Image {
                cache_control: Some(anthropic::CacheControl {
                    cache_type: anthropic::CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }
}