Show the current model and allow clicking on it to change it

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs        | 85 ++++++++++++++++++++++----------
crates/theme/src/theme.rs         |  2 
styles/src/styleTree/assistant.ts | 26 ++++++++-
3 files changed, 82 insertions(+), 31 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -9,15 +9,17 @@ use editor::{Anchor, Editor, ExcerptId, ExcerptRange, MultiBuffer};
 use fs::Fs;
 use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use gpui::{
-    actions, elements::*, executor::Background, Action, AppContext, AsyncAppContext, Entity,
-    ModelContext, ModelHandle, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
-    WindowContext,
+    actions,
+    elements::*,
+    executor::Background,
+    platform::{CursorStyle, MouseButton},
+    Action, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task,
+    View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
 };
 use isahc::{http::StatusCode, Request, RequestExt};
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
 use settings::SettingsStore;
 use std::{cell::RefCell, io, rc::Rc, sync::Arc, time::Duration};
-use tiktoken_rs::model::get_context_size;
 use util::{post_inc, ResultExt, TryFutureExt};
 use workspace::{
     dock::{DockPosition, Panel},
@@ -430,7 +432,7 @@ impl Assistant {
             pending_completions: Default::default(),
             languages: language_registry,
             token_count: None,
-            max_token_count: get_context_size(model),
+            max_token_count: tiktoken_rs::model::get_context_size(model),
             pending_token_count: Task::ready(None),
             model: model.into(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
@@ -483,6 +485,7 @@ impl Assistant {
                     .await?;
 
                 this.update(&mut cx, |this, cx| {
+                    this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
                     this.token_count = Some(token_count);
                     cx.notify()
                 });
@@ -496,6 +499,12 @@ impl Assistant {
         Some(self.max_token_count as isize - self.token_count? as isize)
     }
 
+    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
+        self.model = model;
+        self.count_remaining_tokens(cx);
+        cx.notify();
+    }
+
     fn assist(&mut self, cx: &mut ModelContext<Self>) {
         let messages = self
             .messages
@@ -825,6 +834,16 @@ impl AssistantEditor {
             });
         }
     }
+
+    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
+        self.assistant.update(cx, |assistant, cx| {
+            let new_model = match assistant.model.as_str() {
+                "gpt-4" => "gpt-3.5-turbo",
+                _ => "gpt-4",
+            };
+            assistant.set_model(new_model.into(), cx);
+        });
+    }
 }
 
 impl Entity for AssistantEditor {
@@ -837,27 +856,23 @@ impl View for AssistantEditor {
     }
 
     fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+        enum Model {}
         let theme = &theme::current(cx).assistant;
-        let remaining_tokens = self
-            .assistant
-            .read(cx)
-            .remaining_tokens()
-            .map(|remaining_tokens| {
-                let remaining_tokens_style = if remaining_tokens <= 0 {
-                    &theme.no_remaining_tokens
-                } else {
-                    &theme.remaining_tokens
-                };
-                Label::new(
-                    remaining_tokens.to_string(),
-                    remaining_tokens_style.text.clone(),
-                )
-                .contained()
-                .with_style(remaining_tokens_style.container)
-                .aligned()
-                .top()
-                .right()
-            });
+        let assistant = &self.assistant.read(cx);
+        let model = assistant.model.clone();
+        let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
+            let remaining_tokens_style = if remaining_tokens <= 0 {
+                &theme.no_remaining_tokens
+            } else {
+                &theme.remaining_tokens
+            };
+            Label::new(
+                remaining_tokens.to_string(),
+                remaining_tokens_style.text.clone(),
+            )
+            .contained()
+            .with_style(remaining_tokens_style.container)
+        });
 
         Stack::new()
             .with_child(
@@ -865,7 +880,25 @@ impl View for AssistantEditor {
                     .contained()
                     .with_style(theme.container),
             )
-            .with_children(remaining_tokens)
+            .with_child(
+                Flex::row()
+                    .with_child(
+                        MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
+                            let style = theme.model.style_for(state, false);
+                            Label::new(model, style.text.clone())
+                                .contained()
+                                .with_style(style.container)
+                        })
+                        .with_cursor_style(CursorStyle::PointingHand)
+                        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
+                    )
+                    .with_children(remaining_tokens)
+                    .contained()
+                    .with_style(theme.model_info_container)
+                    .aligned()
+                    .top()
+                    .right(),
+            )
             .into_any()
     }
 

crates/theme/src/theme.rs 🔗

@@ -976,6 +976,8 @@ pub struct AssistantStyle {
     pub sent_at: ContainedText,
     pub user_sender: ContainedText,
     pub assistant_sender: ContainedText,
+    pub model_info_container: ContainerStyle,
+    pub model: Interactive<ContainedText>,
     pub remaining_tokens: ContainedText,
     pub no_remaining_tokens: ContainedText,
     pub api_key_editor: FieldEditor,

styles/src/styleTree/assistant.ts 🔗

@@ -11,7 +11,8 @@ export default function assistant(colorScheme: ColorScheme) {
       },
       header: {
         border: border(layer, "default", { bottom: true, top: true }),
-        margin: { bottom: 6, top: 6 }
+        margin: { bottom: 6, top: 6 },
+        background: editor(colorScheme).background
       },
       user_sender: {
         ...text(layer, "sans", "default", { size: "sm", weight: "bold" }),
@@ -23,17 +24,32 @@ export default function assistant(colorScheme: ColorScheme) {
         margin: { top: 2, left: 8 },
         ...text(layer, "sans", "default", { size: "2xs" }),
       },
-      remaining_tokens: {
-        padding: 4,
+      model_info_container: {
         margin: { right: 16, top: 4 },
+      },
+      model: {
         background: background(layer, "on"),
+        border: border(layer, "on", { overlay: true }),
+        padding: 4,
+        cornerRadius: 4,
+        ...text(layer, "sans", "default", { size: "xs" }),
+        hover: {
+          background: background(layer, "on", "hovered"),
+        }
+      },
+      remaining_tokens: {
+        background: background(layer, "on"),
+        border: border(layer, "on", { overlay: true }),
+        padding: 4,
+        margin: { left: 4 },
         cornerRadius: 4,
         ...text(layer, "sans", "positive", { size: "xs" }),
       },
       no_remaining_tokens: {
-        padding: 4,
-        margin: { right: 16, top: 4 },
         background: background(layer, "on"),
+        border: border(layer, "on", { overlay: true }),
+        padding: 4,
+        margin: { left: 4 },
         cornerRadius: 4,
         ...text(layer, "sans", "negative", { size: "xs" }),
       },