// assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12mod terminal_inline_assistant;
 13
 14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub(crate) use completion_provider::*;
 20pub(crate) use context_store::*;
 21use fs::Fs;
 22use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 23use indexed_docs::IndexedDocsRegistry;
 24pub(crate) use inline_assistant::*;
 25pub(crate) use model_selector::*;
 26use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 27use serde::{Deserialize, Serialize};
 28use settings::{Settings, SettingsStore};
 29use slash_command::{
 30    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 31    file_command, now_command, project_command, prompt_command, search_command, tabs_command,
 32    term_command,
 33};
 34use std::{
 35    fmt::{self, Display},
 36    sync::Arc,
 37};
 38pub(crate) use streaming_diff::*;
 39
// Declares the assistant's actions in the `assistant` namespace. `actions!`
// expands each identifier into a unit struct implementing a gpui action; their
// visibility in the command palette is gated via `Assistant::set_enabled`.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
 58
/// Numeric identifier for a message.
// NOTE(review): the `Ord`/`Hash` derives suggest ids are used as sorted map
// keys — confirm usage in `context_store`.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 63
/// The author of a message in a language-model conversation.
///
/// Serialized in lowercase ("user", "assistant", "system"), matching the
/// `Display` implementation below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 71
 72impl Role {
 73    pub fn cycle(&mut self) {
 74        *self = match self {
 75            Role::User => Role::Assistant,
 76            Role::Assistant => Role::System,
 77            Role::System => Role::User,
 78        }
 79    }
 80}
 81
 82impl Display for Role {
 83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 84        match self {
 85            Role::User => write!(f, "user"),
 86            Role::Assistant => write!(f, "assistant"),
 87            Role::System => write!(f, "system"),
 88        }
 89    }
 90}
 91
/// A language model the assistant can route requests to, tagged by provider.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// Hosted model (reported in telemetry as "zed.dev/<id>").
    Cloud(CloudModel),
    /// Model reached via the OpenAI API.
    OpenAi(OpenAiModel),
    /// Model reached via the Anthropic API.
    Anthropic(AnthropicModel),
    /// Locally served Ollama model.
    Ollama(OllamaModel),
}
 99
100impl Default for LanguageModel {
101    fn default() -> Self {
102        LanguageModel::Cloud(CloudModel::default())
103    }
104}
105
106impl LanguageModel {
107    pub fn telemetry_id(&self) -> String {
108        match self {
109            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
110            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
111            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
112            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
113        }
114    }
115
116    pub fn display_name(&self) -> String {
117        match self {
118            LanguageModel::OpenAi(model) => model.display_name().into(),
119            LanguageModel::Anthropic(model) => model.display_name().into(),
120            LanguageModel::Cloud(model) => model.display_name().into(),
121            LanguageModel::Ollama(model) => model.display_name().into(),
122        }
123    }
124
125    pub fn max_token_count(&self) -> usize {
126        match self {
127            LanguageModel::OpenAi(model) => model.max_token_count(),
128            LanguageModel::Anthropic(model) => model.max_token_count(),
129            LanguageModel::Cloud(model) => model.max_token_count(),
130            LanguageModel::Ollama(model) => model.max_token_count(),
131        }
132    }
133
134    pub fn id(&self) -> &str {
135        match self {
136            LanguageModel::OpenAi(model) => model.id(),
137            LanguageModel::Anthropic(model) => model.id(),
138            LanguageModel::Cloud(model) => model.id(),
139            LanguageModel::Ollama(model) => model.id(),
140        }
141    }
142}
143
/// A single message within a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    // Who authored the message.
    pub role: Role,
    // The message text.
    pub content: String,
}
149
150impl LanguageModelRequestMessage {
151    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
152        proto::LanguageModelRequestMessage {
153            role: match self.role {
154                Role::User => proto::LanguageModelRole::LanguageModelUser,
155                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
156                Role::System => proto::LanguageModelRole::LanguageModelSystem,
157            } as i32,
158            content: self.content.clone(),
159            tool_calls: Vec::new(),
160            tool_call_id: None,
161        }
162    }
163}
164
/// A completion request to send to a language model.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    // The model that should serve the request.
    pub model: LanguageModel,
    // Conversation history, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    // Sequences that stop generation when produced.
    pub stop: Vec<String>,
    // Sampling temperature forwarded to the provider.
    pub temperature: f32,
}
172
173impl LanguageModelRequest {
174    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
175        proto::CompleteWithLanguageModel {
176            model: self.model.id().to_string(),
177            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
178            stop: self.stop.clone(),
179            temperature: self.temperature,
180            tool_choice: None,
181            tools: Vec::new(),
182        }
183    }
184
185    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
186    pub fn preprocess(&mut self) {
187        match &self.model {
188            LanguageModel::OpenAi(_) => {}
189            LanguageModel::Anthropic(_) => {}
190            LanguageModel::Ollama(_) => {}
191            LanguageModel::Cloud(model) => match model {
192                CloudModel::Claude3Opus
193                | CloudModel::Claude3Sonnet
194                | CloudModel::Claude3Haiku
195                | CloudModel::Claude3_5Sonnet => {
196                    preprocess_anthropic_request(self);
197                }
198                _ => {}
199            },
200        }
201    }
202}
203
/// A (possibly partial) message returned by a model; both fields may be
/// absent in streaming deltas (see `LanguageModelChoiceDelta`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
209
/// Token usage reported by the model provider for a completion.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    // Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    // Tokens produced in the completion.
    pub completion_tokens: u32,
    // prompt_tokens + completion_tokens, as reported by the provider.
    pub total_tokens: u32,
}
216
/// One streamed increment of a completion choice.
// NOTE(review): shape mirrors OpenAI-style streaming chat responses — confirm
// against the deserialization site in `completion_provider`.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}
223
/// Bookkeeping for a message: who authored it and its current status.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}
229
/// Lifecycle state of a message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Still being produced.
    Pending,
    /// Completed successfully.
    Done,
    /// Failed; the payload describes the error.
    Error(SharedString),
}
236
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Stored as an app-wide global (set in `init` via `cx.set_global`).
impl Global for Assistant {}
245
246impl Assistant {
247    const NAMESPACE: &'static str = "assistant";
248
249    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
250        if self.enabled == enabled {
251            return;
252        }
253
254        self.enabled = enabled;
255
256        if !enabled {
257            CommandPaletteFilter::update_global(cx, |filter, _cx| {
258                filter.hide_namespace(Self::NAMESPACE);
259            });
260
261            return;
262        }
263
264        CommandPaletteFilter::update_global(cx, |filter, _cx| {
265            filter.show_namespace(Self::NAMESPACE);
266        });
267    }
268}
269
/// Initializes the assistant subsystem: registers settings and globals,
/// starts building the semantic index, wires up the panel, completion
/// provider, slash commands and inline assistants, and keeps the assistant's
/// enabled state in sync with the settings store.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index in the background. The task is detached, so a
    // failure here is dropped rather than surfaced to the caller.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    IndexedDocsRegistry::init_global(cx);

    // Assistant actions start hidden; `set_enabled` below re-shows them when
    // the current settings enable the assistant.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-sync the enabled state whenever settings change.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
314
/// Registers the built-in slash commands with the global registry.
// NOTE(review): the boolean passed to `register_command` appears to control
// whether the command is enabled/featured by default — `fetch` is the only
// command registered with `false`; confirm against
// `SlashCommandRegistry::register_command`.
fn register_slash_commands(cx: &mut AppContext) {
    let slash_command_registry = SlashCommandRegistry::global(cx);
    slash_command_registry.register_command(file_command::FileSlashCommand, true);
    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
    slash_command_registry.register_command(term_command::TermSlashCommand, true);
    slash_command_registry.register_command(now_command::NowSlashCommand, true);
    slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
}
330
/// Formats a token count for display, abbreviating thousands with a "k"
/// suffix (e.g. 950 -> "950", 1234 -> "1.2k", 10500 -> "11k").
pub fn humanize_token_count(count: usize) -> String {
    if count < 1000 {
        // Small counts are shown exactly.
        return count.to_string();
    }

    if count < 10_000 {
        // Four-digit counts keep one decimal place, rounded to the nearest
        // hundred.
        let thousands = count / 1000;
        let hundreds = (count % 1000 + 50) / 100;
        return match hundreds {
            // Rounded down to a whole number of thousands.
            0 => format!("{}k", thousands),
            // Rounding carried into the next thousand (e.g. 1950 -> "2k").
            10 => format!("{}k", thousands + 1),
            _ => format!("{}.{}k", thousands, hundreds),
        };
    }

    // Five digits and up: round to the nearest thousand, no decimal place.
    format!("{}k", (count + 500) / 1000)
}
348
/// Test-only: runs before the test harness (via `ctor`) and enables
/// `env_logger` when `RUST_LOG` is set, so tests can emit log output.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}