Detailed changes
@@ -20,6 +20,7 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"language",
+ "language_model",
"markdown",
"parking_lot",
"project",
@@ -267,6 +268,8 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"language",
+ "language_model",
+ "language_models",
"libc",
"log",
"nix 0.29.0",
@@ -28,6 +28,7 @@ futures.workspace = true
gpui.workspace = true
itertools.workspace = true
language.workspace = true
+language_model.workspace = true
markdown.workspace = true
parking_lot = { workspace = true, optional = true }
project.workspace = true
@@ -3,6 +3,7 @@ use agent_client_protocol::{self as acp};
use anyhow::Result;
use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
+use language_model::LanguageModelProviderId;
use project::Project;
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
@@ -80,12 +81,34 @@ pub trait AgentSessionResume {
}
#[derive(Debug)]
-pub struct AuthRequired;
+pub struct AuthRequired {
+ pub description: Option<String>,
+ pub provider_id: Option<LanguageModelProviderId>,
+}
+
+impl AuthRequired {
+ pub fn new() -> Self {
+ Self {
+ description: None,
+ provider_id: None,
+ }
+ }
+
+ pub fn with_description(mut self, description: String) -> Self {
+ self.description = Some(description);
+ self
+ }
+
+ pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
+ self.provider_id = Some(provider_id);
+ self
+ }
+}
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "AuthRequired")
+ write!(f, "Authentication required")
}
}
@@ -27,6 +27,8 @@ futures.workspace = true
gpui.workspace = true
indoc.workspace = true
itertools.workspace = true
+language_model.workspace = true
+language_models.workspace = true
log.workspace = true
paths.workspace = true
project.workspace = true
@@ -437,7 +437,7 @@ impl AgentConnection for AcpConnection {
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
- anyhow::bail!(AuthRequired)
+ anyhow::bail!(AuthRequired::new())
}
cx.update(|cx| {
@@ -140,7 +140,13 @@ impl AgentConnection for AcpConnection {
.await
.map_err(|err| {
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
- anyhow!(AuthRequired)
+ let mut error = AuthRequired::new();
+
+ if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
+ error = error.with_description(err.message);
+ }
+
+ anyhow!(error)
} else {
anyhow!(err)
}
@@ -3,6 +3,7 @@ pub mod tools;
use collections::HashMap;
use context_server::listener::McpServerTool;
+use language_models::provider::anthropic::AnthropicLanguageModelProvider;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
@@ -30,7 +31,7 @@ use util::{ResultExt, debug_panic};
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
-use acp_thread::{AcpThread, AgentConnection};
+use acp_thread::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)]
pub struct ClaudeCode;
@@ -79,6 +80,36 @@ impl AgentConnection for ClaudeAgentConnection {
) -> Task<Result<Entity<AcpThread>>> {
let cwd = cwd.to_owned();
cx.spawn(async move |cx| {
+ let settings = cx.read_global(|settings: &SettingsStore, _| {
+ settings.get::<AllAgentServersSettings>(None).claude.clone()
+ })?;
+
+ let Some(command) = AgentServerCommand::resolve(
+ "claude",
+ &[],
+ Some(&util::paths::home_dir().join(".claude/local/claude")),
+ settings,
+ &project,
+ cx,
+ )
+ .await
+ else {
+ anyhow::bail!("Failed to find claude binary");
+ };
+
+ let api_key =
+ cx.update(AnthropicLanguageModelProvider::api_key)?
+ .await
+ .map_err(|err| {
+ if err.is::<language_model::AuthenticateError>() {
+ anyhow!(AuthRequired::new().with_language_model_provider(
+ language_model::ANTHROPIC_PROVIDER_ID
+ ))
+ } else {
+ anyhow!(err)
+ }
+ })?;
+
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
@@ -98,23 +129,6 @@ impl AgentConnection for ClaudeAgentConnection {
.await?;
mcp_config_file.flush().await?;
- let settings = cx.read_global(|settings: &SettingsStore, _| {
- settings.get::<AllAgentServersSettings>(None).claude.clone()
- })?;
-
- let Some(command) = AgentServerCommand::resolve(
- "claude",
- &[],
- Some(&util::paths::home_dir().join(".claude/local/claude")),
- settings,
- &project,
- cx,
- )
- .await
- else {
- anyhow::bail!("Failed to find claude binary");
- };
-
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
@@ -126,6 +140,7 @@ impl AgentConnection for ClaudeAgentConnection {
&command,
ClaudeSessionMode::Start,
session_id.clone(),
+ api_key,
&mcp_config_path,
&cwd,
)?;
@@ -320,6 +335,7 @@ fn spawn_claude(
command: &AgentServerCommand,
mode: ClaudeSessionMode,
session_id: acp::SessionId,
+ api_key: language_models::provider::anthropic::ApiKey,
mcp_config_path: &Path,
root_dir: &Path,
) -> Result<Child> {
@@ -355,6 +371,8 @@ fn spawn_claude(
ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
})
.args(command.args.iter().map(|arg| arg.as_str()))
+ .envs(command.env.iter().flatten())
+ .env("ANTHROPIC_API_KEY", api_key.key)
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
@@ -1,6 +1,7 @@
use acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
- LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId,
+ AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
+ UserMessageId,
};
use acp_thread::{AgentConnection, Plan};
use action_log::ActionLog;
@@ -18,13 +19,16 @@ use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
use file_icons::FileIcons;
use fs::Fs;
use gpui::{
- Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement,
- Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
- PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle,
- TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
- linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between,
+ Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem,
+ EdgesRefinement, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState,
+ MouseButton, PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task,
+ TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window,
+ WindowHandle, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*,
+ pulsating_between,
};
use language::Buffer;
+
+use language_model::LanguageModelRegistry;
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
use project::Project;
use prompt_store::PromptId;
@@ -137,6 +141,9 @@ enum ThreadState {
LoadError(LoadError),
Unauthenticated {
connection: Rc<dyn AgentConnection>,
+ description: Option<Entity<Markdown>>,
+ configuration_view: Option<AnyView>,
+ _subscription: Option<Subscription>,
},
ServerExited {
status: ExitStatus,
@@ -267,19 +274,16 @@ impl AcpThreadView {
};
let result = match result.await {
- Err(e) => {
- let mut cx = cx.clone();
- if e.is::<acp_thread::AuthRequired>() {
- this.update(&mut cx, |this, cx| {
- this.thread_state = ThreadState::Unauthenticated { connection };
- cx.notify();
+ Err(e) => match e.downcast::<acp_thread::AuthRequired>() {
+ Ok(err) => {
+ cx.update(|window, cx| {
+ Self::handle_auth_required(this, err, agent, connection, window, cx)
})
- .ok();
+ .log_err();
return;
- } else {
- Err(e)
}
- }
+ Err(err) => Err(err),
+ },
Ok(thread) => Ok(thread),
};
@@ -345,6 +349,68 @@ impl AcpThreadView {
ThreadState::Loading { _task: load_task }
}
+ fn handle_auth_required(
+ this: WeakEntity<Self>,
+ err: AuthRequired,
+ agent: Rc<dyn AgentServer>,
+ connection: Rc<dyn AgentConnection>,
+ window: &mut Window,
+ cx: &mut App,
+ ) {
+ let agent_name = agent.name();
+ let (configuration_view, subscription) = if let Some(provider_id) = err.provider_id {
+ let registry = LanguageModelRegistry::global(cx);
+
+ let sub = window.subscribe(®istry, cx, {
+ let provider_id = provider_id.clone();
+ let this = this.clone();
+ move |_, ev, window, cx| {
+ if let language_model::Event::ProviderStateChanged(updated_provider_id) = &ev {
+ if &provider_id == updated_provider_id {
+ this.update(cx, |this, cx| {
+ this.thread_state = Self::initial_state(
+ agent.clone(),
+ this.workspace.clone(),
+ this.project.clone(),
+ window,
+ cx,
+ );
+ cx.notify();
+ })
+ .ok();
+ }
+ }
+ }
+ });
+
+ let view = registry.read(cx).provider(&provider_id).map(|provider| {
+ provider.configuration_view(
+ language_model::ConfigurationViewTargetAgent::Other(agent_name),
+ window,
+ cx,
+ )
+ });
+
+ (view, Some(sub))
+ } else {
+ (None, None)
+ };
+
+ this.update(cx, |this, cx| {
+ this.thread_state = ThreadState::Unauthenticated {
+ connection,
+ configuration_view,
+ description: err
+ .description
+ .clone()
+ .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))),
+ _subscription: subscription,
+ };
+ cx.notify();
+ })
+ .ok();
+ }
+
fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context<Self>) {
if let Some(load_err) = err.downcast_ref::<LoadError>() {
self.thread_state = ThreadState::LoadError(load_err.clone());
@@ -369,7 +435,7 @@ impl AcpThreadView {
ThreadState::Ready { thread, .. } => thread.read(cx).title(),
ThreadState::Loading { .. } => "Loadingβ¦".into(),
ThreadState::LoadError(_) => "Failed to load".into(),
- ThreadState::Unauthenticated { .. } => "Not authenticated".into(),
+ ThreadState::Unauthenticated { .. } => "Authentication Required".into(),
ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
}
}
@@ -708,7 +774,7 @@ impl AcpThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
+ let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
return;
};
@@ -1841,19 +1907,53 @@ impl AcpThreadView {
.into_any()
}
- fn render_pending_auth_state(&self) -> AnyElement {
+ fn render_auth_required_state(
+ &self,
+ connection: &Rc<dyn AgentConnection>,
+ description: Option<&Entity<Markdown>>,
+ configuration_view: Option<&AnyView>,
+ window: &mut Window,
+ cx: &Context<Self>,
+ ) -> Div {
v_flex()
+ .p_2()
+ .gap_2()
+ .flex_1()
.items_center()
.justify_center()
- .child(self.render_error_agent_logo())
.child(
- h_flex()
- .mt_4()
- .mb_1()
+ v_flex()
+ .items_center()
.justify_center()
- .child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)),
+ .child(self.render_error_agent_logo())
+ .child(
+ h_flex().mt_4().mb_1().justify_center().child(
+ Headline::new("Authentication Required").size(HeadlineSize::Medium),
+ ),
+ )
+ .into_any(),
)
- .into_any()
+ .children(description.map(|desc| {
+ div().text_ui(cx).text_center().child(
+ self.render_markdown(desc.clone(), default_markdown_style(false, window, cx)),
+ )
+ }))
+ .children(
+ configuration_view
+ .cloned()
+ .map(|view| div().px_4().w_full().max_w_128().child(view)),
+ )
+ .child(h_flex().mt_1p5().justify_center().children(
+ connection.auth_methods().into_iter().map(|method| {
+ Button::new(SharedString::from(method.id.0.clone()), method.name.clone())
+ .on_click({
+ let method_id = method.id.clone();
+ cx.listener(move |this, _, window, cx| {
+ this.authenticate(method_id.clone(), window, cx)
+ })
+ })
+ }),
+ ))
}
fn render_server_exited(&self, status: ExitStatus, _cx: &Context<Self>) -> AnyElement {
@@ -3347,26 +3447,18 @@ impl Render for AcpThreadView {
.on_action(cx.listener(Self::toggle_burn_mode))
.bg(cx.theme().colors().panel_background)
.child(match &self.thread_state {
- ThreadState::Unauthenticated { connection } => v_flex()
- .p_2()
- .flex_1()
- .items_center()
- .justify_center()
- .child(self.render_pending_auth_state())
- .child(h_flex().mt_1p5().justify_center().children(
- connection.auth_methods().into_iter().map(|method| {
- Button::new(
- SharedString::from(method.id.0.clone()),
- method.name.clone(),
- )
- .on_click({
- let method_id = method.id.clone();
- cx.listener(move |this, _, window, cx| {
- this.authenticate(method_id.clone(), window, cx)
- })
- })
- }),
- )),
+ ThreadState::Unauthenticated {
+ connection,
+ description,
+ configuration_view,
+ ..
+ } => self.render_auth_required_state(
+ &connection,
+ description.as_ref(),
+ configuration_view.as_ref(),
+ window,
+ cx,
+ ),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
ThreadState::LoadError(e) => v_flex()
.p_2()
@@ -137,7 +137,11 @@ impl AgentConfiguration {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let configuration_view = provider.configuration_view(window, cx);
+ let configuration_view = provider.configuration_view(
+ language_model::ConfigurationViewTargetAgent::ZedAgent,
+ window,
+ cx,
+ );
self.configuration_views_by_provider
.insert(provider.id(), configuration_view);
}
@@ -320,7 +320,7 @@ fn init_language_model_settings(cx: &mut App) {
cx.subscribe(
&LanguageModelRegistry::global(cx),
|_, event: &language_model::Event, cx| match event {
- language_model::Event::ProviderStateChanged
+ language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
update_active_language_model_from_settings(cx);
@@ -104,7 +104,7 @@ impl LanguageModelPickerDelegate {
window,
|picker, _, event, window, cx| {
match event {
- language_model::Event::ProviderStateChanged
+ language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
let query = picker.query(cx);
@@ -11,7 +11,7 @@ impl ApiKeysWithProviders {
cx.subscribe(
&LanguageModelRegistry::global(cx),
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
- language_model::Event::ProviderStateChanged
+ language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_configured_providers(cx)
@@ -25,7 +25,7 @@ impl AgentPanelOnboarding {
cx.subscribe(
&LanguageModelRegistry::global(cx),
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
- language_model::Event::ProviderStateChanged
+ language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_available_providers(cx)
@@ -201,3 +201,9 @@ impl Drop for Subscription {
}
}
}
+
+impl std::fmt::Debug for Subscription {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Subscription").finish()
+ }
+}
@@ -1,8 +1,8 @@
use crate::{
- AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
- LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice,
+ AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
+ LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelRequest, LanguageModelToolChoice,
};
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@@ -62,7 +62,12 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
Task::ready(Ok(()))
}
- fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: ConfigurationViewTargetAgent,
+ _window: &mut Window,
+ _: &mut App,
+ ) -> AnyView {
unimplemented!()
}
@@ -634,7 +634,12 @@ pub trait LanguageModelProvider: 'static {
}
fn is_authenticated(&self, cx: &App) -> bool;
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
+ fn configuration_view(
+ &self,
+ target_agent: ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView;
fn must_accept_terms(&self, _cx: &App) -> bool {
false
}
@@ -648,6 +653,13 @@ pub trait LanguageModelProvider: 'static {
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
+#[derive(Default, Clone, Copy)]
+pub enum ConfigurationViewTargetAgent {
+ #[default]
+ ZedAgent,
+ Other(&'static str),
+}
+
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
/// When there are some past interactions in the Agent Panel.
@@ -107,7 +107,7 @@ pub enum Event {
InlineAssistantModelChanged,
CommitMessageModelChanged,
ThreadSummaryModelChanged,
- ProviderStateChanged,
+ ProviderStateChanged(LanguageModelProviderId),
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
}
@@ -148,8 +148,11 @@ impl LanguageModelRegistry {
) {
let id = provider.id();
- let subscription = provider.subscribe(cx, |_, cx| {
- cx.emit(Event::ProviderStateChanged);
+ let subscription = provider.subscribe(cx, {
+ let id = id.clone();
+ move |_, cx| {
+ cx.emit(Event::ProviderStateChanged(id.clone()));
+ }
});
if let Some(subscription) = subscription {
subscription.detach();
@@ -15,11 +15,11 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
- LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider,
- LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent,
- RateLimiter, Role,
+ AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
+ LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
+ LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
+ LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@@ -153,29 +153,14 @@ impl State {
return Task::ready(Ok(()));
}
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .anthropic
- .api_url
- .clone();
+ let key = AnthropicLanguageModelProvider::api_key(cx);
cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, &cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
+ let key = key.await?;
this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
+ this.api_key = Some(key.key);
+ this.api_key_from_env = key.from_env;
cx.notify();
})?;
@@ -184,6 +169,11 @@ impl State {
}
}
+pub struct ApiKey {
+ pub key: String,
+ pub from_env: bool,
+}
+
impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
@@ -206,6 +196,33 @@ impl AnthropicLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
+
+ pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .anthropic
+ .api_url
+ .clone();
+
+ if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
+ Task::ready(Ok(ApiKey {
+ key,
+ from_env: true,
+ }))
+ } else {
+ cx.spawn(async move |cx| {
+ let (_, api_key) = credentials_provider
+ .read_credentials(&api_url, &cx)
+ .await?
+ .ok_or(AuthenticateError::CredentialsNotFound)?;
+
+ Ok(ApiKey {
+ key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
+ from_env: false,
+ })
+ })
+ }
+ }
}
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
@@ -299,8 +316,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
- cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ fn configuration_view(
+ &self,
+ target_agent: ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
.into()
}
@@ -902,12 +924,18 @@ struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
+ target_agent: ConfigurationViewTargetAgent,
}
impl ConfigurationView {
const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
- fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+ fn new(
+ state: gpui::Entity<State>,
+ target_agent: ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
cx.observe(&state, |_, _, cx| {
cx.notify();
})
@@ -939,6 +967,7 @@ impl ConfigurationView {
}),
state,
load_credentials_task,
+ target_agent,
}
}
@@ -1012,7 +1041,10 @@ impl Render for ConfigurationView {
v_flex()
.size_full()
.on_action(cx.listener(Self::save_api_key))
- .child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:"))
+ .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match self.target_agent {
+ ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic",
+ ConfigurationViewTargetAgent::Other(agent) => agent,
+ })))
.child(
List::new()
.child(
@@ -1023,7 +1055,7 @@ impl Render for ConfigurationView {
)
)
.child(
- InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant")
+ InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent")
)
)
.child(
@@ -348,7 +348,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -391,7 +391,12 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
Task::ready(Ok(()))
}
- fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ _: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|_| ConfigurationView::new(self.state.clone()))
.into()
}
@@ -176,7 +176,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
Task::ready(Err(err.into()))
}
- fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ _: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, cx)).into()
}
@@ -229,7 +229,12 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -277,7 +277,12 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -226,7 +226,12 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ _window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, cx)).into()
}
@@ -243,7 +243,12 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -255,7 +255,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, window, cx))
.into()
@@ -233,7 +233,12 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -243,7 +243,12 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -306,7 +306,12 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -230,7 +230,12 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -230,7 +230,12 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
- fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
@@ -329,7 +329,11 @@ impl AiConfigurationModal {
cx: &mut Context<Self>,
) -> Self {
let focus_handle = cx.focus_handle();
- let configuration_view = selected_provider.configuration_view(window, cx);
+ let configuration_view = selected_provider.configuration_view(
+ language_model::ConfigurationViewTargetAgent::ZedAgent,
+ window,
+ cx,
+ );
Self {
focus_handle,