assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12mod terminal_inline_assistant;
 13
 14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub(crate) use completion_provider::*;
 20pub(crate) use context_store::*;
 21use fs::Fs;
 22use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 23pub(crate) use inline_assistant::*;
 24pub(crate) use model_selector::*;
 25use rustdoc::RustdocStore;
 26use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 27use serde::{Deserialize, Serialize};
 28use settings::{Settings, SettingsStore};
 29use slash_command::{
 30    active_command, default_command, diagnostics_command, fetch_command, file_command, now_command,
 31    project_command, prompt_command, rustdoc_command, search_command, tabs_command, term_command,
 32};
 33use std::{
 34    fmt::{self, Display},
 35    sync::Arc,
 36};
 37pub(crate) use streaming_diff::*;
 38
// GPUI actions in the `assistant` namespace. These are the user-invokable
// commands (keybindings / command palette); their visibility in the palette is
// toggled via `CommandPaletteFilter` based on whether the assistant is enabled.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
 56
/// Identifier for a message within a conversation.
/// Newtype over `usize`; ordered and hashable so it can key maps and be sorted.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 61
/// Who authored a message in a conversation.
/// Serialized in lowercase ("user" / "assistant" / "system") to match the
/// wire format expected by LLM provider APIs.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 69
 70impl Role {
 71    pub fn cycle(&mut self) {
 72        *self = match self {
 73            Role::User => Role::Assistant,
 74            Role::Assistant => Role::System,
 75            Role::System => Role::User,
 76        }
 77    }
 78}
 79
 80impl Display for Role {
 81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 82        match self {
 83            Role::User => write!(f, "user"),
 84            Role::Assistant => write!(f, "assistant"),
 85            Role::System => write!(f, "system"),
 86        }
 87    }
 88}
 89
/// A language model the assistant can use, tagged by provider.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model hosted via Zed's cloud service (zed.dev).
    Cloud(CloudModel),
    /// A model accessed through the OpenAI API.
    OpenAi(OpenAiModel),
    /// A model accessed through the Anthropic API.
    Anthropic(AnthropicModel),
    /// A model served via Ollama.
    Ollama(OllamaModel),
}
 97
 98impl Default for LanguageModel {
 99    fn default() -> Self {
100        LanguageModel::Cloud(CloudModel::default())
101    }
102}
103
104impl LanguageModel {
105    pub fn telemetry_id(&self) -> String {
106        match self {
107            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
108            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
109            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
110            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
111        }
112    }
113
114    pub fn display_name(&self) -> String {
115        match self {
116            LanguageModel::OpenAi(model) => model.display_name().into(),
117            LanguageModel::Anthropic(model) => model.display_name().into(),
118            LanguageModel::Cloud(model) => model.display_name().into(),
119            LanguageModel::Ollama(model) => model.display_name().into(),
120        }
121    }
122
123    pub fn max_token_count(&self) -> usize {
124        match self {
125            LanguageModel::OpenAi(model) => model.max_token_count(),
126            LanguageModel::Anthropic(model) => model.max_token_count(),
127            LanguageModel::Cloud(model) => model.max_token_count(),
128            LanguageModel::Ollama(model) => model.max_token_count(),
129        }
130    }
131
132    pub fn id(&self) -> &str {
133        match self {
134            LanguageModel::OpenAi(model) => model.id(),
135            LanguageModel::Anthropic(model) => model.id(),
136            LanguageModel::Cloud(model) => model.id(),
137            LanguageModel::Ollama(model) => model.id(),
138        }
139    }
140}
141
/// A single message in a completion request sent to a language model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    // Author of the message (user / assistant / system).
    pub role: Role,
    // Plain-text message body.
    pub content: String,
}
147
148impl LanguageModelRequestMessage {
149    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
150        proto::LanguageModelRequestMessage {
151            role: match self.role {
152                Role::User => proto::LanguageModelRole::LanguageModelUser,
153                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
154                Role::System => proto::LanguageModelRole::LanguageModelSystem,
155            } as i32,
156            content: self.content.clone(),
157            tool_calls: Vec::new(),
158            tool_call_id: None,
159        }
160    }
161}
162
/// A full completion request: target model plus conversation and sampling
/// parameters.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    // Which model should serve this request.
    pub model: LanguageModel,
    // Conversation history, oldest first.
    pub messages: Vec<LanguageModelRequestMessage>,
    // Stop sequences that terminate generation when emitted.
    pub stop: Vec<String>,
    // Sampling temperature passed through to the provider.
    pub temperature: f32,
}
170
171impl LanguageModelRequest {
172    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
173        proto::CompleteWithLanguageModel {
174            model: self.model.id().to_string(),
175            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
176            stop: self.stop.clone(),
177            temperature: self.temperature,
178            tool_choice: None,
179            tools: Vec::new(),
180        }
181    }
182
183    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
184    pub fn preprocess(&mut self) {
185        match &self.model {
186            LanguageModel::OpenAi(_) => {}
187            LanguageModel::Anthropic(_) => {}
188            LanguageModel::Ollama(_) => {}
189            LanguageModel::Cloud(model) => match model {
190                CloudModel::Claude3Opus
191                | CloudModel::Claude3Sonnet
192                | CloudModel::Claude3Haiku
193                | CloudModel::Claude3_5Sonnet => {
194                    preprocess_anthropic_request(self);
195                }
196                _ => {}
197            },
198        }
199    }
200}
201
/// A (possibly partial) message received from a language model.
/// Both fields are optional because streamed deltas may carry only one of them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
207
/// Token usage statistics attached to a completion response.
/// NOTE(review): field names match the OpenAI usage schema — presumably
/// deserialized directly from a provider response; confirm at the call site.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
214
/// One streamed chunk of a completion choice.
/// NOTE(review): shape mirrors OpenAI's streaming `choices[]` entries
/// (index + delta + optional finish reason) — confirm against the provider.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}
221
/// Per-message bookkeeping kept alongside the message text.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    // Who authored the message.
    role: Role,
    // Current lifecycle state of the message.
    status: MessageStatus,
}
227
/// Lifecycle state of a message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Still awaiting completion.
    Pending,
    /// Finished successfully.
    Done,
    /// Failed; carries the error text to show the user.
    Error(SharedString),
}
234
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Registered as a GPUI global so it can be read and updated from any context.
impl Global for Assistant {}
243
244impl Assistant {
245    const NAMESPACE: &'static str = "assistant";
246
247    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
248        if self.enabled == enabled {
249            return;
250        }
251
252        self.enabled = enabled;
253
254        if !enabled {
255            CommandPaletteFilter::update_global(cx, |filter, _cx| {
256                filter.hide_namespace(Self::NAMESPACE);
257            });
258
259            return;
260        }
261
262        CommandPaletteFilter::update_global(cx, |filter, _cx| {
263            filter.show_namespace(Self::NAMESPACE);
264        });
265    }
266}
267
/// Initializes the assistant subsystem: registers settings and globals,
/// kicks off the semantic index, wires up slash commands and panels, and
/// keeps the enabled state in sync with user settings.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index in the background. Errors from
    // `SemanticIndex::new` propagate out of the spawned future, which is
    // detached below (no one awaits its result).
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    RustdocStore::init_global(cx);

    // Start with assistant actions hidden in the command palette; they are
    // re-shown just below iff settings say the assistant is enabled.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply the enabled flag whenever the settings store changes.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
312
313fn register_slash_commands(cx: &mut AppContext) {
314    let slash_command_registry = SlashCommandRegistry::global(cx);
315    slash_command_registry.register_command(file_command::FileSlashCommand, true);
316    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
317    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
318    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
319    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
320    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
321    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
322    slash_command_registry.register_command(term_command::TermSlashCommand, true);
323    slash_command_registry.register_command(now_command::NowSlashCommand, true);
324    slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
325    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
326    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
327}
328
/// Formats a token count compactly for display: exact below 1000 ("999"),
/// rounded to a tenth of a thousand up to 9999 ("1.2k", with carry so 1950
/// becomes "2k"), and rounded to the nearest thousand above that ("12k").
pub fn humanize_token_count(count: usize) -> String {
    if count < 1000 {
        return count.to_string();
    }

    if count < 10000 {
        let whole = count / 1000;
        // Round the remainder to the nearest hundred (tenth of a thousand).
        let tenths = (count % 1000 + 50) / 100;
        return match tenths {
            0 => format!("{}k", whole),
            // Rounding carried over into the next thousand (e.g. 1950 -> "2k").
            10 => format!("{}k", whole + 1),
            _ => format!("{}.{}k", whole, tenths),
        };
    }

    format!("{}k", (count + 500) / 1000)
}
346
/// Test-only constructor: enables `env_logger` output when `RUST_LOG` is set.
/// `#[ctor::ctor]` runs this before the test harness's `main`, so logging is
/// configured before any test executes.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}