Refine UX for assistants (#13502)

Antonio Scandurra created

<img width="1652" alt="image"
src="https://github.com/zed-industries/zed/assets/482957/376d1915-1e15-4d6c-966e-48f55f7cb249">


Release Notes:

- N/A

Change summary

crates/assistant/src/assistant.rs        |  26 +
crates/assistant/src/assistant_panel.rs  |  71 ++-
crates/assistant/src/inline_assistant.rs | 511 +++++++++++++++++--------
crates/assistant/src/prompt_library.rs   |   2 
crates/zed/src/main.rs                   |   2 
crates/zed/src/zed.rs                    |   2 
6 files changed, 415 insertions(+), 199 deletions(-)

Detailed changes

crates/assistant/src/assistant.rs 🔗

@@ -10,14 +10,14 @@ mod search;
 mod slash_command;
 mod streaming_diff;
 
-pub use assistant_panel::AssistantPanel;
-
+pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 use assistant_slash_command::SlashCommandRegistry;
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
 pub(crate) use completion_provider::*;
 pub(crate) use context_store::*;
+use fs::Fs;
 use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 pub(crate) use inline_assistant::*;
 pub(crate) use model_selector::*;
@@ -264,7 +264,7 @@ impl Assistant {
     }
 }
 
-pub fn init(client: Arc<Client>, cx: &mut AppContext) {
+pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
     cx.set_global(Assistant::default());
     AssistantSettings::register(cx);
 
@@ -288,7 +288,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     assistant_slash_command::init(cx);
     register_slash_commands(cx);
     assistant_panel::init(cx);
-    inline_assistant::init(client.telemetry().clone(), cx);
+    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
     RustdocStore::init_global(cx);
 
     CommandPaletteFilter::update_global(cx, |filter, _cx| {
@@ -324,6 +324,24 @@ fn register_slash_commands(cx: &mut AppContext) {
     slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
 }
 
+pub fn humanize_token_count(count: usize) -> String {
+    match count {
+        0..=999 => count.to_string(),
+        1000..=9999 => {
+            let thousands = count / 1000;
+            let hundreds = (count % 1000 + 50) / 100;
+            if hundreds == 0 {
+                format!("{}k", thousands)
+            } else if hundreds == 10 {
+                format!("{}k", thousands + 1)
+            } else {
+                format!("{}.{}k", thousands, hundreds)
+            }
+        }
+        _ => format!("{}k", (count + 500) / 1000),
+    }
+}
+
 #[cfg(test)]
 #[ctor::ctor]
 fn init_logger() {

crates/assistant/src/assistant_panel.rs 🔗

@@ -1,5 +1,6 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings},
+    humanize_token_count,
     prompt_library::open_prompt_library,
     search::*,
     slash_command::{
@@ -89,6 +90,10 @@ pub fn init(cx: &mut AppContext) {
     .detach();
 }
 
+pub enum AssistantPanelEvent {
+    ContextEdited,
+}
+
 pub struct AssistantPanel {
     workspace: WeakView<Workspace>,
     width: Option<Pixels>,
@@ -360,11 +365,11 @@ impl AssistantPanel {
             return;
         }
 
-        let Some(assistant) = workspace.panel::<AssistantPanel>(cx) else {
+        let Some(assistant_panel) = workspace.panel::<AssistantPanel>(cx) else {
             return;
         };
 
-        let context_editor = assistant
+        let context_editor = assistant_panel
             .read(cx)
             .active_context_editor()
             .and_then(|editor| {
@@ -391,25 +396,37 @@ impl AssistantPanel {
             return;
         };
 
-        if assistant.update(cx, |assistant, cx| assistant.is_authenticated(cx)) {
+        if assistant_panel.update(cx, |panel, cx| panel.is_authenticated(cx)) {
             InlineAssistant::update_global(cx, |assistant, cx| {
                 assistant.assist(
                     &active_editor,
                     Some(cx.view().downgrade()),
-                    include_context,
+                    include_context.then_some(&assistant_panel),
                     cx,
                 )
             })
         } else {
-            let assistant = assistant.downgrade();
+            let assistant_panel = assistant_panel.downgrade();
             cx.spawn(|workspace, mut cx| async move {
-                assistant
+                assistant_panel
                     .update(&mut cx, |assistant, cx| assistant.authenticate(cx))?
                     .await?;
-                if assistant.update(&mut cx, |assistant, cx| assistant.is_authenticated(cx))? {
+                if assistant_panel
+                    .update(&mut cx, |assistant, cx| assistant.is_authenticated(cx))?
+                {
                     cx.update(|cx| {
+                        let assistant_panel = if include_context {
+                            assistant_panel.upgrade()
+                        } else {
+                            None
+                        };
                         InlineAssistant::update_global(cx, |assistant, cx| {
-                            assistant.assist(&active_editor, Some(workspace), include_context, cx)
+                            assistant.assist(
+                                &active_editor,
+                                Some(workspace),
+                                assistant_panel.as_ref(),
+                                cx,
+                            )
                         })
                     })?
                 } else {
@@ -460,7 +477,7 @@ impl AssistantPanel {
             _subscriptions: subscriptions,
         });
         self.show_saved_contexts = false;
-
+        cx.emit(AssistantPanelEvent::ContextEdited);
         cx.notify();
     }
 
@@ -472,6 +489,7 @@ impl AssistantPanel {
     ) {
         match event {
             ContextEditorEvent::TabContentChanged => cx.notify(),
+            ContextEditorEvent::Edited => cx.emit(AssistantPanelEvent::ContextEdited),
         }
     }
 
@@ -863,18 +881,33 @@ impl AssistantPanel {
         context: &Model<Context>,
         cx: &mut ViewContext<Self>,
     ) -> Option<impl IntoElement> {
-        let remaining_tokens = context.read(cx).remaining_tokens(cx)?;
-        let remaining_tokens_color = if remaining_tokens <= 0 {
+        let model = CompletionProvider::global(cx).model();
+        let token_count = context.read(cx).token_count()?;
+        let max_token_count = model.max_token_count();
+
+        let remaining_tokens = max_token_count as isize - token_count as isize;
+        let token_count_color = if remaining_tokens <= 0 {
             Color::Error
-        } else if remaining_tokens <= 500 {
+        } else if token_count as f32 / max_token_count as f32 >= 0.8 {
             Color::Warning
         } else {
             Color::Muted
         };
+
         Some(
-            Label::new(remaining_tokens.to_string())
-                .size(LabelSize::Small)
-                .color(remaining_tokens_color),
+            h_flex()
+                .gap_0p5()
+                .child(
+                    Label::new(humanize_token_count(token_count))
+                        .size(LabelSize::Small)
+                        .color(token_count_color),
+                )
+                .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
+                .child(
+                    Label::new(humanize_token_count(max_token_count))
+                        .size(LabelSize::Small)
+                        .color(Color::Muted),
+                ),
         )
     }
 }
@@ -978,6 +1011,7 @@ impl Panel for AssistantPanel {
 }
 
 impl EventEmitter<PanelEvent> for AssistantPanel {}
+impl EventEmitter<AssistantPanelEvent> for AssistantPanel {}
 
 impl FocusableView for AssistantPanel {
     fn focus_handle(&self, _cx: &AppContext) -> FocusHandle {
@@ -1538,11 +1572,6 @@ impl Context {
         }
     }
 
-    fn remaining_tokens(&self, cx: &AppContext) -> Option<isize> {
-        let model = CompletionProvider::global(cx).model();
-        Some(model.max_token_count() as isize - self.token_count? as isize)
-    }
-
     fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
         self.count_remaining_tokens(cx);
     }
@@ -2183,6 +2212,7 @@ struct PendingCompletion {
 }
 
 enum ContextEditorEvent {
+    Edited,
     TabContentChanged,
 }
 
@@ -2775,6 +2805,7 @@ impl ContextEditor {
             EditorEvent::SelectionsChanged { .. } => {
                 self.scroll_position = self.cursor_scroll_position(cx);
             }
+            EditorEvent::BufferEdited => cx.emit(ContextEditorEvent::Edited),
             _ => {}
         }
     }

crates/assistant/src/inline_assistant.rs 🔗

@@ -1,8 +1,9 @@
 use crate::{
-    prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
-    LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
+    assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
+    AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest,
+    LanguageModelRequestMessage, Role, StreamingDiff,
 };
-use anyhow::{Context as _, Result};
+use anyhow::{anyhow, Context as _, Result};
 use client::telemetry::Telemetry;
 use collections::{hash_map, HashMap, HashSet, VecDeque};
 use editor::{
@@ -14,6 +15,7 @@ use editor::{
     Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
     ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
 };
+use fs::Fs;
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
 use gpui::{
     point, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global,
@@ -24,7 +26,7 @@ use language::{Buffer, Point, Selection, TransactionId};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use rope::Rope;
-use settings::Settings;
+use settings::{update_settings_file, Settings};
 use similar::TextDiff;
 use std::{
     cmp, mem,
@@ -32,15 +34,15 @@ use std::{
     pin::Pin,
     sync::Arc,
     task::{self, Poll},
-    time::Instant,
+    time::{Duration, Instant},
 };
 use theme::ThemeSettings;
-use ui::{prelude::*, Tooltip};
+use ui::{prelude::*, ContextMenu, PopoverMenu, Tooltip};
 use util::RangeExt;
 use workspace::{notifications::NotificationId, Toast, Workspace};
 
-pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
-    cx.set_global(InlineAssistant::new(telemetry));
+pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
+    cx.set_global(InlineAssistant::new(fs, telemetry));
 }
 
 const PROMPT_HISTORY_MAX_LEN: usize = 20;
@@ -53,12 +55,13 @@ pub struct InlineAssistant {
     assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
     prompt_history: VecDeque<String>,
     telemetry: Option<Arc<Telemetry>>,
+    fs: Arc<dyn Fs>,
 }
 
 impl Global for InlineAssistant {}
 
 impl InlineAssistant {
-    pub fn new(telemetry: Arc<Telemetry>) -> Self {
+    pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>) -> Self {
         Self {
             next_assist_id: InlineAssistId::default(),
             next_assist_group_id: InlineAssistGroupId::default(),
@@ -67,6 +70,7 @@ impl InlineAssistant {
             assist_groups: HashMap::default(),
             prompt_history: VecDeque::default(),
             telemetry: Some(telemetry),
+            fs,
         }
     }
 
@@ -74,7 +78,7 @@ impl InlineAssistant {
         &mut self,
         editor: &View<Editor>,
         workspace: Option<WeakView<Workspace>>,
-        include_context: bool,
+        assistant_panel: Option<&View<AssistantPanel>>,
         cx: &mut WindowContext,
     ) {
         let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
@@ -151,7 +155,10 @@ impl InlineAssistant {
                     self.prompt_history.clone(),
                     prompt_buffer.clone(),
                     codegen.clone(),
+                    editor,
+                    assistant_panel,
                     workspace.clone(),
+                    self.fs.clone(),
                     cx,
                 )
             });
@@ -208,7 +215,7 @@ impl InlineAssistant {
                 InlineAssist::new(
                     assist_id,
                     assist_group_id,
-                    include_context,
+                    assistant_panel.is_some(),
                     editor,
                     &prompt_editor,
                     block_ids[0],
@@ -706,8 +713,6 @@ impl InlineAssistant {
             return;
         }
 
-        assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
-
         let Some(user_prompt) = assist
             .decorations
             .as_ref()
@@ -716,115 +721,138 @@ impl InlineAssistant {
             return;
         };
 
-        let context = if assist.include_context {
-            assist.workspace.as_ref().and_then(|workspace| {
-                let workspace = workspace.upgrade()?.read(cx);
-                let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
-                assistant_panel.read(cx).active_context(cx)
-            })
-        } else {
-            None
-        };
-
-        let editor = if let Some(editor) = assist.editor.upgrade() {
-            editor
-        } else {
-            return;
-        };
-
-        let project_name = assist.workspace.as_ref().and_then(|workspace| {
-            let workspace = workspace.upgrade()?;
-            Some(
-                workspace
-                    .read(cx)
-                    .project()
-                    .read(cx)
-                    .worktree_root_names(cx)
-                    .collect::<Vec<&str>>()
-                    .join("/"),
-            )
-        });
-
         self.prompt_history.retain(|prompt| *prompt != user_prompt);
         self.prompt_history.push_back(user_prompt.clone());
         if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
             self.prompt_history.pop_front();
         }
 
+        assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
         let codegen = assist.codegen.clone();
-        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
-        let range = codegen.read(cx).range.clone();
-        let start = snapshot.point_to_buffer_offset(range.start);
-        let end = snapshot.point_to_buffer_offset(range.end);
-        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
-            let (start_buffer, start_buffer_offset) = start;
-            let (end_buffer, end_buffer_offset) = end;
-            if start_buffer.remote_id() == end_buffer.remote_id() {
-                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
-            } else {
-                self.finish_assist(assist_id, false, cx);
-                return;
-            }
-        } else {
-            self.finish_assist(assist_id, false, cx);
-            return;
-        };
+        let request = self.request_for_inline_assist(assist_id, cx);
 
-        let language = buffer.language_at(range.start);
-        let language_name = if let Some(language) = language.as_ref() {
-            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
-                None
-            } else {
-                Some(language.name())
-            }
-        } else {
-            None
-        };
+        cx.spawn(|mut cx| async move {
+            let request = request.await?;
+            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
+            anyhow::Ok(())
+        })
+        .detach_and_log_err(cx);
+    }
 
-        // Higher Temperature increases the randomness of model outputs.
-        // If Markdown or No Language is Known, increase the randomness for more creative output
-        // If Code, decrease temperature to get more deterministic outputs
-        let temperature = if let Some(language) = language_name.clone() {
-            if language.as_ref() == "Markdown" {
-                1.0
+    fn request_for_inline_assist(
+        &self,
+        assist_id: InlineAssistId,
+        cx: &mut WindowContext,
+    ) -> Task<Result<LanguageModelRequest>> {
+        cx.spawn(|mut cx| async move {
+            let (user_prompt, context_request, project_name, buffer, range, model) = cx
+                .read_global(|this: &InlineAssistant, cx: &WindowContext| {
+                    let assist = this.assists.get(&assist_id).context("invalid assist")?;
+                    let decorations = assist.decorations.as_ref().context("invalid assist")?;
+                    let editor = assist.editor.upgrade().context("invalid assist")?;
+                    let user_prompt = decorations.prompt_editor.read(cx).prompt(cx);
+                    let context_request = if assist.include_context {
+                        assist.workspace.as_ref().and_then(|workspace| {
+                            let workspace = workspace.upgrade()?.read(cx);
+                            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
+                            Some(
+                                assistant_panel
+                                    .read(cx)
+                                    .active_context(cx)?
+                                    .read(cx)
+                                    .to_completion_request(cx),
+                            )
+                        })
+                    } else {
+                        None
+                    };
+                    let project_name = assist.workspace.as_ref().and_then(|workspace| {
+                        let workspace = workspace.upgrade()?;
+                        Some(
+                            workspace
+                                .read(cx)
+                                .project()
+                                .read(cx)
+                                .worktree_root_names(cx)
+                                .collect::<Vec<&str>>()
+                                .join("/"),
+                        )
+                    });
+                    let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
+                    let range = assist.codegen.read(cx).range.clone();
+                    let model = CompletionProvider::global(cx).model();
+                    anyhow::Ok((
+                        user_prompt,
+                        context_request,
+                        project_name,
+                        buffer,
+                        range,
+                        model,
+                    ))
+                })??;
+
+            let language = buffer.language_at(range.start);
+            let language_name = if let Some(language) = language.as_ref() {
+                if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
+                    None
+                } else {
+                    Some(language.name())
+                }
             } else {
-                0.5
-            }
-        } else {
-            1.0
-        };
+                None
+            };
 
-        let prompt = cx.background_executor().spawn(async move {
-            let language_name = language_name.as_deref();
-            generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
-        });
+            // Higher Temperature increases the randomness of model outputs.
+            // If Markdown or No Language is Known, increase the randomness for more creative output
+            // If Code, decrease temperature to get more deterministic outputs
+            let temperature = if let Some(language) = language_name.clone() {
+                if language.as_ref() == "Markdown" {
+                    1.0
+                } else {
+                    0.5
+                }
+            } else {
+                1.0
+            };
 
-        let mut messages = Vec::new();
-        if let Some(context) = context {
-            let request = context.read(cx).to_completion_request(cx);
-            messages = request.messages;
-        }
-        let model = CompletionProvider::global(cx).model();
+            let prompt = cx
+                .background_executor()
+                .spawn(async move {
+                    let language_name = language_name.as_deref();
+                    let start = buffer.point_to_buffer_offset(range.start);
+                    let end = buffer.point_to_buffer_offset(range.end);
+                    let (buffer, range) = if let Some((start, end)) = start.zip(end) {
+                        let (start_buffer, start_buffer_offset) = start;
+                        let (end_buffer, end_buffer_offset) = end;
+                        if start_buffer.remote_id() == end_buffer.remote_id() {
+                            (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
+                        } else {
+                            return Err(anyhow!("invalid transformation range"));
+                        }
+                    } else {
+                        return Err(anyhow!("invalid transformation range"));
+                    };
+                    generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
+                })
+                .await?;
 
-        cx.spawn(|mut cx| async move {
-            let prompt = prompt.await?;
+            let mut messages = Vec::new();
+            if let Some(context_request) = context_request {
+                messages = context_request.messages;
+            }
 
             messages.push(LanguageModelRequestMessage {
                 role: Role::User,
                 content: prompt,
             });
 
-            let request = LanguageModelRequest {
+            Ok(LanguageModelRequest {
                 model,
                 messages,
                 stop: vec!["|END|>".to_string()],
                 temperature,
-            };
-
-            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
-            anyhow::Ok(())
+            })
         })
-        .detach_and_log_err(cx);
     }
 
     fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
@@ -1142,6 +1170,7 @@ enum PromptEditorEvent {
 
 struct PromptEditor {
     id: InlineAssistId,
+    fs: Arc<dyn Fs>,
     height_in_lines: u8,
     editor: View<Editor>,
     edited_since_done: bool,
@@ -1150,9 +1179,12 @@ struct PromptEditor {
     prompt_history_ix: Option<usize>,
     pending_prompt: String,
     codegen: Model<Codegen>,
-    workspace: Option<WeakView<Workspace>>,
     _codegen_subscription: Subscription,
     editor_subscriptions: Vec<Subscription>,
+    pending_token_count: Task<Result<()>>,
+    token_count: Option<usize>,
+    _token_count_subscriptions: Vec<Subscription>,
+    workspace: Option<WeakView<Workspace>>,
 }
 
 impl EventEmitter<PromptEditorEvent> for PromptEditor {}
@@ -1160,6 +1192,7 @@ impl EventEmitter<PromptEditorEvent> for PromptEditor {}
 impl Render for PromptEditor {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         let gutter_dimensions = *self.gutter_dimensions.lock();
+        let fs = self.fs.clone();
 
         let buttons = match &self.codegen.read(cx).status {
             CodegenStatus::Idle => {
@@ -1245,85 +1278,100 @@ impl Render for PromptEditor {
             }
         };
 
-        v_flex().h_full().w_full().justify_end().child(
-            h_flex()
-                .bg(cx.theme().colors().editor_background)
-                .border_y_1()
-                .border_color(cx.theme().status().info_border)
-                .py_1p5()
-                .w_full()
-                .on_action(cx.listener(Self::confirm))
-                .on_action(cx.listener(Self::cancel))
-                .on_action(cx.listener(Self::move_up))
-                .on_action(cx.listener(Self::move_down))
-                .child(
-                    h_flex()
-                        .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
-                        // .pr(gutter_dimensions.fold_area_width())
-                        .justify_center()
-                        .gap_2()
-                        .children(self.workspace.clone().map(|workspace| {
-                            IconButton::new("context", IconName::Context)
-                                .size(ButtonSize::None)
-                                .icon_size(IconSize::XSmall)
-                                .icon_color(Color::Muted)
-                                .on_click({
-                                    let workspace = workspace.clone();
-                                    cx.listener(move |_, _, cx| {
-                                        workspace
-                                            .update(cx, |workspace, cx| {
-                                                workspace.focus_panel::<AssistantPanel>(cx);
-                                            })
-                                            .ok();
-                                    })
+        h_flex()
+            .bg(cx.theme().colors().editor_background)
+            .border_y_1()
+            .border_color(cx.theme().status().info_border)
+            .py_1p5()
+            .h_full()
+            .w_full()
+            .on_action(cx.listener(Self::confirm))
+            .on_action(cx.listener(Self::cancel))
+            .on_action(cx.listener(Self::move_up))
+            .on_action(cx.listener(Self::move_down))
+            .child(
+                h_flex()
+                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
+                    .justify_center()
+                    .gap_2()
+                    .child(
+                        PopoverMenu::new("model-switcher")
+                            .menu(move |cx| {
+                                ContextMenu::build(cx, |mut menu, cx| {
+                                    for model in CompletionProvider::global(cx).available_models() {
+                                        menu = menu.custom_entry(
+                                            {
+                                                let model = model.clone();
+                                                move |_| {
+                                                    Label::new(model.display_name())
+                                                        .into_any_element()
+                                                }
+                                            },
+                                            {
+                                                let fs = fs.clone();
+                                                let model = model.clone();
+                                                move |cx| {
+                                                    let model = model.clone();
+                                                    update_settings_file::<AssistantSettings>(
+                                                        fs.clone(),
+                                                        cx,
+                                                        move |settings| settings.set_model(model),
+                                                    );
+                                                }
+                                            },
+                                        );
+                                    }
+                                    menu
                                 })
-                                .tooltip(move |cx| {
-                                    let token_count = workspace.upgrade().and_then(|workspace| {
-                                        let panel =
-                                            workspace.read(cx).panel::<AssistantPanel>(cx)?;
-                                        let context = panel.read(cx).active_context(cx)?;
-                                        context.read(cx).token_count()
-                                    });
-                                    if let Some(token_count) = token_count {
+                                .into()
+                            })
+                            .trigger(
+                                IconButton::new("context", IconName::Settings)
+                                    .size(ButtonSize::None)
+                                    .icon_size(IconSize::Small)
+                                    .icon_color(Color::Muted)
+                                    .tooltip(move |cx| {
                                         Tooltip::with_meta(
                                             format!(
-                                                "{} Additional Context Tokens from Assistant",
-                                                token_count
+                                                "Using {}",
+                                                CompletionProvider::global(cx)
+                                                    .model()
+                                                    .display_name()
                                             ),
-                                            Some(&crate::ToggleFocus),
-                                            "Click to open…",
+                                            None,
+                                            "Click to Change Model",
                                             cx,
                                         )
-                                    } else {
-                                        Tooltip::for_action(
-                                            "Toggle Assistant Panel",
-                                            &crate::ToggleFocus,
-                                            cx,
-                                        )
-                                    }
-                                })
-                        }))
-                        .children(
-                            if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
-                                let error_message = SharedString::from(error.to_string());
-                                Some(
-                                    div()
-                                        .id("error")
-                                        .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
-                                        .child(
-                                            Icon::new(IconName::XCircle)
-                                                .size(IconSize::Small)
-                                                .color(Color::Error),
-                                        ),
-                                )
-                            } else {
-                                None
-                            },
-                        ),
-                )
-                .child(div().flex_1().child(self.render_prompt_editor(cx)))
-                .child(h_flex().gap_2().pr_4().children(buttons)),
-        )
+                                    }),
+                            )
+                            .anchor(gpui::AnchorCorner::BottomRight),
+                    )
+                    .children(
+                        if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
+                            let error_message = SharedString::from(error.to_string());
+                            Some(
+                                div()
+                                    .id("error")
+                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
+                                    .child(
+                                        Icon::new(IconName::XCircle)
+                                            .size(IconSize::Small)
+                                            .color(Color::Error),
+                                    ),
+                            )
+                        } else {
+                            None
+                        },
+                    ),
+            )
+            .child(div().flex_1().child(self.render_prompt_editor(cx)))
+            .child(
+                h_flex()
+                    .gap_2()
+                    .pr_4()
+                    .children(self.render_token_count(cx))
+                    .children(buttons),
+            )
     }
 }
 
@@ -1336,13 +1384,17 @@ impl FocusableView for PromptEditor {
 impl PromptEditor {
     const MAX_LINES: u8 = 8;
 
+    #[allow(clippy::too_many_arguments)]
     fn new(
         id: InlineAssistId,
         gutter_dimensions: Arc<Mutex<GutterDimensions>>,
         prompt_history: VecDeque<String>,
         prompt_buffer: Model<MultiBuffer>,
         codegen: Model<Codegen>,
+        parent_editor: &View<Editor>,
+        assistant_panel: Option<&View<AssistantPanel>>,
         workspace: Option<WeakView<Workspace>>,
+        fs: Arc<dyn Fs>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let prompt_editor = cx.new_view(|cx| {
@@ -1363,6 +1415,15 @@ impl PromptEditor {
             editor.set_placeholder_text("Add a prompt…", cx);
             editor
         });
+
+        let mut token_count_subscriptions = Vec::new();
+        token_count_subscriptions
+            .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
+        if let Some(assistant_panel) = assistant_panel {
+            token_count_subscriptions
+                .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
+        }
+
         let mut this = Self {
             id,
             height_in_lines: 1,
@@ -1375,9 +1436,14 @@ impl PromptEditor {
             _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
             editor_subscriptions: Vec::new(),
             codegen,
+            fs,
+            pending_token_count: Task::ready(Ok(())),
+            token_count: None,
+            _token_count_subscriptions: token_count_subscriptions,
             workspace,
         };
         this.count_lines(cx);
+        this.count_tokens(cx);
         this.subscribe_to_editor(cx);
         this
     }
@@ -1436,6 +1502,47 @@ impl PromptEditor {
         }
     }
 
+    fn handle_parent_editor_event(
+        &mut self,
+        _: View<Editor>,
+        event: &EditorEvent,
+        cx: &mut ViewContext<Self>,
+    ) {
+        if let EditorEvent::BufferEdited { .. } = event {
+            self.count_tokens(cx);
+        }
+    }
+
+    fn handle_assistant_panel_event(
+        &mut self,
+        _: View<AssistantPanel>,
+        event: &AssistantPanelEvent,
+        cx: &mut ViewContext<Self>,
+    ) {
+        let AssistantPanelEvent::ContextEdited { .. } = event;
+        self.count_tokens(cx);
+    }
+
+    fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
+        let assist_id = self.id;
+        self.pending_token_count = cx.spawn(|this, mut cx| async move {
+            cx.background_executor().timer(Duration::from_secs(1)).await;
+            let request = cx
+                .update_global(|inline_assistant: &mut InlineAssistant, cx| {
+                    inline_assistant.request_for_inline_assist(assist_id, cx)
+                })?
+                .await?;
+
+            let token_count = cx
+                .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
+                .await?;
+            this.update(&mut cx, |this, cx| {
+                this.token_count = Some(token_count);
+                cx.notify();
+            })
+        })
+    }
+
     fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
         self.count_lines(cx);
     }
@@ -1460,6 +1567,9 @@ impl PromptEditor {
                 self.edited_since_done = true;
                 cx.notify();
             }
+            EditorEvent::BufferEdited => {
+                self.count_tokens(cx);
+            }
             _ => {}
         }
     }
@@ -1551,6 +1661,63 @@ impl PromptEditor {
         }
     }
 
+    fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
+        let model = CompletionProvider::global(cx).model();
+        let token_count = self.token_count?;
+        let max_token_count = model.max_token_count();
+
+        let remaining_tokens = max_token_count as isize - token_count as isize;
+        let token_count_color = if remaining_tokens <= 0 {
+            Color::Error
+        } else if token_count as f32 / max_token_count as f32 >= 0.8 {
+            Color::Warning
+        } else {
+            Color::Muted
+        };
+
+        let mut token_count = h_flex()
+            .id("token_count")
+            .gap_0p5()
+            .child(
+                Label::new(humanize_token_count(token_count))
+                    .size(LabelSize::Small)
+                    .color(token_count_color),
+            )
+            .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
+            .child(
+                Label::new(humanize_token_count(max_token_count))
+                    .size(LabelSize::Small)
+                    .color(Color::Muted),
+            );
+        if let Some(workspace) = self.workspace.clone() {
+            token_count = token_count
+                .tooltip(|cx| {
+                    Tooltip::with_meta(
+                        "Tokens Used by Inline Assistant",
+                        None,
+                        "Click to Open Assistant Panel",
+                        cx,
+                    )
+                })
+                .cursor_pointer()
+                .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
+                .on_click(move |_, cx| {
+                    cx.stop_propagation();
+                    workspace
+                        .update(cx, |workspace, cx| {
+                            workspace.focus_panel::<AssistantPanel>(cx)
+                        })
+                        .ok();
+                });
+        } else {
+            token_count = token_count
+                .cursor_default()
+                .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
+        }
+
+        Some(token_count)
+    }
+
     fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         let settings = ThemeSettings::get_global(cx);
         let text_style = TextStyle {

crates/assistant/src/prompt_library.rs 🔗

@@ -569,7 +569,7 @@ impl PromptLibrary {
         let provider = CompletionProvider::global(cx);
         if provider.is_authenticated() {
             InlineAssistant::update_global(cx, |assistant, cx| {
-                assistant.assist(&prompt_editor, None, false, cx)
+                assistant.assist(&prompt_editor, None, None, cx)
             })
         } else {
             for window in cx.windows() {

crates/zed/src/main.rs 🔗

@@ -219,7 +219,7 @@ fn init_ui(app_state: Arc<AppState>, cx: &mut AppContext) -> Result<()> {
 
     inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
 
-    assistant::init(app_state.client.clone(), cx);
+    assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
 
     repl::init(app_state.fs.clone(), cx);
 

crates/zed/src/zed.rs 🔗

@@ -3181,7 +3181,7 @@ mod tests {
             project_panel::init((), cx);
             outline_panel::init((), cx);
             terminal_view::init(cx);
-            assistant::init(app_state.client.clone(), cx);
+            assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
             tasks_ui::init(cx);
             initialize_workspace(app_state.clone(), cx);
             app_state