// assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12mod terminal_inline_assistant;
 13
 14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub(crate) use completion_provider::*;
 20pub(crate) use context_store::*;
 21use fs::Fs;
 22use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 23use indexed_docs::IndexedDocsRegistry;
 24pub(crate) use inline_assistant::*;
 25pub(crate) use model_selector::*;
 26use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 27use serde::{Deserialize, Serialize};
 28use settings::{Settings, SettingsStore};
 29use slash_command::{
 30    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 31    file_command, now_command, project_command, prompt_command, search_command, tabs_command,
 32    term_command,
 33};
 34use std::{
 35    fmt::{self, Display},
 36    sync::Arc,
 37};
 38pub(crate) use streaming_diff::*;
 39
// Declares the action types dispatched in the `assistant` namespace
// (the same namespace hidden/shown by `Assistant::set_enabled` below).
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
 59
/// Newtype identifier for a message (wraps a `usize`).
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 64
/// The author of a conversation message; serialized in lowercase
/// ("user", "assistant", "system"), matching the `Display` impl below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 72
 73impl Role {
 74    pub fn cycle(&mut self) {
 75        *self = match self {
 76            Role::User => Role::Assistant,
 77            Role::Assistant => Role::System,
 78            Role::System => Role::User,
 79        }
 80    }
 81}
 82
 83impl Display for Role {
 84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 85        match self {
 86            Role::User => write!(f, "user"),
 87            Role::Assistant => write!(f, "assistant"),
 88            Role::System => write!(f, "system"),
 89        }
 90    }
 91}
 92
/// A language model from one of the supported providers.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    Ollama(OllamaModel),
}
100
101impl Default for LanguageModel {
102    fn default() -> Self {
103        LanguageModel::Cloud(CloudModel::default())
104    }
105}
106
107impl LanguageModel {
108    pub fn telemetry_id(&self) -> String {
109        match self {
110            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
111            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
112            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
113            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
114        }
115    }
116
117    pub fn display_name(&self) -> String {
118        match self {
119            LanguageModel::OpenAi(model) => model.display_name().into(),
120            LanguageModel::Anthropic(model) => model.display_name().into(),
121            LanguageModel::Cloud(model) => model.display_name().into(),
122            LanguageModel::Ollama(model) => model.display_name().into(),
123        }
124    }
125
126    pub fn max_token_count(&self) -> usize {
127        match self {
128            LanguageModel::OpenAi(model) => model.max_token_count(),
129            LanguageModel::Anthropic(model) => model.max_token_count(),
130            LanguageModel::Cloud(model) => model.max_token_count(),
131            LanguageModel::Ollama(model) => model.max_token_count(),
132        }
133    }
134
135    pub fn id(&self) -> &str {
136        match self {
137            LanguageModel::OpenAi(model) => model.id(),
138            LanguageModel::Anthropic(model) => model.id(),
139            LanguageModel::Cloud(model) => model.id(),
140            LanguageModel::Ollama(model) => model.id(),
141        }
142    }
143}
144
/// A single message in a completion request sent to a language model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}
150
151impl LanguageModelRequestMessage {
152    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
153        proto::LanguageModelRequestMessage {
154            role: match self.role {
155                Role::User => proto::LanguageModelRole::LanguageModelUser,
156                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
157                Role::System => proto::LanguageModelRole::LanguageModelSystem,
158            } as i32,
159            content: self.content.clone(),
160            tool_calls: Vec::new(),
161            tool_call_id: None,
162        }
163    }
164}
165
/// A completion request: the model to use, the conversation so far,
/// stop sequences, and the sampling temperature.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub stop: Vec<String>,
    pub temperature: f32,
}
173
174impl LanguageModelRequest {
175    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
176        proto::CompleteWithLanguageModel {
177            model: self.model.id().to_string(),
178            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
179            stop: self.stop.clone(),
180            temperature: self.temperature,
181            tool_choice: None,
182            tools: Vec::new(),
183        }
184    }
185
186    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
187    pub fn preprocess(&mut self) {
188        match &self.model {
189            LanguageModel::OpenAi(_) => {}
190            LanguageModel::Anthropic(_) => {}
191            LanguageModel::Ollama(_) => {}
192            LanguageModel::Cloud(model) => match model {
193                CloudModel::Claude3Opus
194                | CloudModel::Claude3Sonnet
195                | CloudModel::Claude3Haiku
196                | CloudModel::Claude3_5Sonnet => {
197                    preprocess_anthropic_request(self);
198                }
199                _ => {}
200            },
201        }
202    }
203}
204
/// A message fragment in a model response; used as the `delta` in
/// [`LanguageModelChoiceDelta`]. Both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
210
/// Token usage reported for a completion.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
217
/// One choice delta from a completion response, pairing the message
/// fragment with its choice index and optional finish reason.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}
224
/// Per-message bookkeeping: who authored the message and its status.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}
230
/// Lifecycle state of a message: in flight, completed, or failed with
/// an error description.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}
237
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Registered as a GPUI global in `init` via `cx.set_global`.
impl Global for Assistant {}
246
247impl Assistant {
248    const NAMESPACE: &'static str = "assistant";
249
250    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
251        if self.enabled == enabled {
252            return;
253        }
254
255        self.enabled = enabled;
256
257        if !enabled {
258            CommandPaletteFilter::update_global(cx, |filter, _cx| {
259                filter.hide_namespace(Self::NAMESPACE);
260            });
261
262            return;
263        }
264
265        CommandPaletteFilter::update_global(cx, |filter, _cx| {
266            filter.show_namespace(Self::NAMESPACE);
267        });
268    }
269}
270
/// Initializes the assistant subsystem: registers the `Assistant` global and
/// settings, kicks off background construction of the semantic index,
/// initializes the prompt library, completion providers, slash commands,
/// panels, and inline assistants, and keeps the command-palette namespace in
/// sync with the `enabled` setting.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index asynchronously; it is installed as a global
    // once ready. Errors from `SemanticIndex::new` propagate via `?` and are
    // dropped by `detach()`.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    IndexedDocsRegistry::init_global(cx);

    // Start hidden; `set_enabled` below shows the namespace if the setting is
    // on, and the settings observer keeps it in sync afterwards.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
315
/// Registers all built-in slash commands with the global registry.
// NOTE(review): the trailing bool appears to mark the command as
// featured/enabled-by-default (only `/fetch` passes `false`) — confirm
// against `SlashCommandRegistry::register_command`.
fn register_slash_commands(cx: &mut AppContext) {
    let slash_command_registry = SlashCommandRegistry::global(cx);
    slash_command_registry.register_command(file_command::FileSlashCommand, true);
    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
    slash_command_registry.register_command(term_command::TermSlashCommand, true);
    slash_command_registry.register_command(now_command::NowSlashCommand, true);
    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
}
331
/// Formats a token count for display: exact below 1000 ("999"), one decimal
/// place between 1k and 10k ("1.5k"), and whole thousands from 10k up
/// ("11k"), rounding half-up throughout.
pub fn humanize_token_count(count: usize) -> String {
    if count < 1000 {
        return count.to_string();
    }
    if count < 10_000 {
        let thousands = count / 1000;
        // Remainder rounded to the nearest tenth of a thousand.
        let tenths = (count % 1000 + 50) / 100;
        return match tenths {
            0 => format!("{}k", thousands),
            // A full ten tenths carries into the next thousand (9950 -> "10k").
            10 => format!("{}k", thousands + 1),
            _ => format!("{}.{}k", thousands, tenths),
        };
    }
    format!("{}k", (count + 500) / 1000)
}
349
// Test-only constructor (runs before the test harness via `ctor`):
// initializes `env_logger` when RUST_LOG is set so tests emit logs.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}
356}