assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context;
  5pub mod context_store;
  6mod inline_assistant;
  7mod model_selector;
  8mod prompt_library;
  9mod prompts;
 10mod slash_command;
 11mod streaming_diff;
 12mod terminal_inline_assistant;
 13
 14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub use completion_provider::*;
 20pub use context::*;
 21pub use context_store::*;
 22use fs::Fs;
 23use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 24use indexed_docs::IndexedDocsRegistry;
 25pub(crate) use inline_assistant::*;
 26pub(crate) use model_selector::*;
 27use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 28use serde::{Deserialize, Serialize};
 29use settings::{Settings, SettingsStore};
 30use slash_command::{
 31    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 32    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 33    tabs_command, term_command,
 34};
 35use std::{
 36    fmt::{self, Display},
 37    sync::Arc,
 38};
 39pub(crate) use streaming_diff::*;
 40
// Actions declared in the `assistant` namespace. The namespace name matches
// `Assistant::NAMESPACE`, which is hidden from or shown in the command palette
// filter depending on whether the assistant is enabled.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugEditSteps
    ]
);
 59
/// Action in the `assistant` namespace that can carry an optional prompt.
///
/// NOTE(review): presumably dispatched to start an inline assist (see the
/// `inline_assistant` module) — confirm against the handler.
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    /// Optional prompt text supplied with the action; `None` when omitted.
    prompt: Option<String>,
}

// Unlike the actions listed in `actions!`, `InlineAssist` has a field and derives
// `Deserialize`, so it is registered through `impl_actions!` instead.
impl_actions!(assistant, [InlineAssist]);
 66
/// Identifier of a message, implemented as a newtype around a Lamport timestamp.
///
/// Equality, ordering and hashing are all derived from the wrapped
/// `clock::Lamport` value.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct MessageId(clock::Lamport);
 69
 70impl MessageId {
 71    pub fn as_u64(self) -> u64 {
 72        self.0.as_u64()
 73    }
 74}
 75
/// The author role attached to a message.
///
/// Serialized by serde using the lowercase variant name (`"user"`,
/// `"assistant"`, `"system"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// A message with the user role.
    User,
    /// A message with the assistant role.
    Assistant,
    /// A message with the system role.
    System,
}
 83
 84impl Role {
 85    pub fn from_proto(role: i32) -> Role {
 86        match proto::LanguageModelRole::from_i32(role) {
 87            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
 88            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
 89            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
 90            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
 91            None => Role::User,
 92        }
 93    }
 94
 95    pub fn to_proto(&self) -> proto::LanguageModelRole {
 96        match self {
 97            Role::User => proto::LanguageModelRole::LanguageModelUser,
 98            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
 99            Role::System => proto::LanguageModelRole::LanguageModelSystem,
100        }
101    }
102
103    pub fn cycle(self) -> Role {
104        match self {
105            Role::User => Role::Assistant,
106            Role::Assistant => Role::System,
107            Role::System => Role::User,
108        }
109    }
110}
111
112impl Display for Role {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
114        match self {
115            Role::User => write!(f, "user"),
116            Role::Assistant => write!(f, "assistant"),
117            Role::System => write!(f, "system"),
118        }
119    }
120}
121
/// A language model selection, tagged by the provider it belongs to.
///
/// Each variant wraps that provider's own model type from `assistant_settings`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model identified with the `zed.dev/` prefix in telemetry.
    Cloud(CloudModel),
    /// A model identified with the `openai/` prefix in telemetry.
    OpenAi(OpenAiModel),
    /// A model identified with the `anthropic/` prefix in telemetry.
    Anthropic(AnthropicModel),
    /// A model identified with the `ollama/` prefix in telemetry.
    Ollama(OllamaModel),
}
129
130impl Default for LanguageModel {
131    fn default() -> Self {
132        LanguageModel::Cloud(CloudModel::default())
133    }
134}
135
136impl LanguageModel {
137    pub fn telemetry_id(&self) -> String {
138        match self {
139            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
140            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
141            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
142            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
143        }
144    }
145
146    pub fn display_name(&self) -> String {
147        match self {
148            LanguageModel::OpenAi(model) => model.display_name().into(),
149            LanguageModel::Anthropic(model) => model.display_name().into(),
150            LanguageModel::Cloud(model) => model.display_name().into(),
151            LanguageModel::Ollama(model) => model.display_name().into(),
152        }
153    }
154
155    pub fn max_token_count(&self) -> usize {
156        match self {
157            LanguageModel::OpenAi(model) => model.max_token_count(),
158            LanguageModel::Anthropic(model) => model.max_token_count(),
159            LanguageModel::Cloud(model) => model.max_token_count(),
160            LanguageModel::Ollama(model) => model.max_token_count(),
161        }
162    }
163
164    pub fn id(&self) -> &str {
165        match self {
166            LanguageModel::OpenAi(model) => model.id(),
167            LanguageModel::Anthropic(model) => model.id(),
168            LanguageModel::Cloud(model) => model.id(),
169            LanguageModel::Ollama(model) => model.id(),
170        }
171    }
172}
173
/// A single message of a [`LanguageModelRequest`]: a role plus its text content.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// The role the message is attributed to.
    pub role: Role,
    /// The text of the message.
    pub content: String,
}
179
180impl LanguageModelRequestMessage {
181    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
182        proto::LanguageModelRequestMessage {
183            role: self.role.to_proto() as i32,
184            content: self.content.clone(),
185            tool_calls: Vec::new(),
186            tool_call_id: None,
187        }
188    }
189}
190
/// A completion request: the model to use, the messages, and the stop and
/// temperature settings.
///
/// The derived `Default` yields the default [`LanguageModel`], no messages, no
/// stop sequences and a temperature of `0.0`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
    /// The model the request is addressed to.
    pub model: LanguageModel,
    /// The messages of the request, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop strings forwarded with the request.
    pub stop: Vec<String>,
    /// Temperature forwarded with the request.
    pub temperature: f32,
}
198
199impl LanguageModelRequest {
200    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
201        proto::CompleteWithLanguageModel {
202            model: self.model.id().to_string(),
203            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
204            stop: self.stop.clone(),
205            temperature: self.temperature,
206            tool_choice: None,
207            tools: Vec::new(),
208        }
209    }
210
211    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
212    pub fn preprocess(&mut self) {
213        match &self.model {
214            LanguageModel::OpenAi(_) => {}
215            LanguageModel::Anthropic(_) => {}
216            LanguageModel::Ollama(_) => {}
217            LanguageModel::Cloud(model) => match model {
218                CloudModel::Claude3Opus
219                | CloudModel::Claude3Sonnet
220                | CloudModel::Claude3Haiku
221                | CloudModel::Claude3_5Sonnet => {
222                    preprocess_anthropic_request(self);
223                }
224                _ => {}
225            },
226        }
227    }
228}
229
/// A response message in which both the role and the content are optional.
///
/// NOTE(review): presumably a streamed partial message where either field may
/// be absent in a given chunk — confirm against the producer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// The role of the message, if present.
    pub role: Option<Role>,
    /// The text content, if present.
    pub content: Option<String>,
}
235
/// Token usage figures deserialized from a model response.
///
/// NOTE(review): field meanings are taken from their names; confirm against the
/// API that produces this payload.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Token count attributed to the prompt.
    pub prompt_tokens: u32,
    /// Token count attributed to the completion.
    pub completion_tokens: u32,
    /// Total token count as reported in the response.
    pub total_tokens: u32,
}
242
/// One choice entry deserialized from a model response, carrying a message delta.
///
/// NOTE(review): presumably one element of a streamed `choices` array — confirm
/// against the API that produces this payload.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of this choice.
    pub index: u32,
    /// The message delta for this choice.
    pub delta: LanguageModelResponseMessage,
    /// The finish reason, if one was provided.
    pub finish_reason: Option<String>,
}
249
/// The status of a message.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    /// The message is pending. Also used when a protobuf status has no variant set.
    Pending,
    /// The message is done.
    Done,
    /// The message failed; carries the error text.
    Error(SharedString),
}
256
257impl MessageStatus {
258    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
259        match status.variant {
260            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
261            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
262            Some(proto::context_message_status::Variant::Error(error)) => {
263                MessageStatus::Error(error.message.into())
264            }
265            None => MessageStatus::Pending,
266        }
267    }
268
269    pub fn to_proto(&self) -> proto::ContextMessageStatus {
270        match self {
271            MessageStatus::Pending => proto::ContextMessageStatus {
272                variant: Some(proto::context_message_status::Variant::Pending(
273                    proto::context_message_status::Pending {},
274                )),
275            },
276            MessageStatus::Done => proto::ContextMessageStatus {
277                variant: Some(proto::context_message_status::Variant::Done(
278                    proto::context_message_status::Done {},
279                )),
280            },
281            MessageStatus::Error(message) => proto::ContextMessageStatus {
282                variant: Some(proto::context_message_status::Variant::Error(
283                    proto::context_message_status::Error {
284                        message: message.to_string(),
285                    },
286                )),
287            },
288        }
289    }
290}
291
/// The state pertaining to the Assistant.
///
/// The derived `Default` starts with `enabled == false`.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Marks `Assistant` as a gpui global so it can be installed with `set_global`
// and mutated through `update_global`.
impl Global for Assistant {}
300
301impl Assistant {
302    const NAMESPACE: &'static str = "assistant";
303
304    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
305        if self.enabled == enabled {
306            return;
307        }
308
309        self.enabled = enabled;
310
311        if !enabled {
312            CommandPaletteFilter::update_global(cx, |filter, _cx| {
313                filter.hide_namespace(Self::NAMESPACE);
314            });
315
316            return;
317        }
318
319        CommandPaletteFilter::update_global(cx, |filter, _cx| {
320            filter.show_namespace(Self::NAMESPACE);
321        });
322    }
323}
324
/// Initializes the assistant.
///
/// Installs the `Assistant` global and registers `AssistantSettings`, starts
/// building the semantic index in a spawned task, initializes the assistant's
/// submodules and slash commands, and finally keeps the command palette
/// visibility of the assistant namespace in sync with `AssistantSettings::enabled`.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index asynchronously and install it as a global once
    // it is ready. The task is detached, so an error from `SemanticIndex::new`
    // ends the task early and is not observed here.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    context_store::init(&client);
    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    IndexedDocsRegistry::init_global(cx);

    // Hide the namespace explicitly first: `Assistant::default()` starts with
    // `enabled == false`, and `set_enabled(false)` returns early when the state
    // is unchanged, so it would not hide the namespace on its own.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    // Apply the current setting once at startup...
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // ...and re-apply it from the `SettingsStore` global observer.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
370
371fn register_slash_commands(cx: &mut AppContext) {
372    let slash_command_registry = SlashCommandRegistry::global(cx);
373    slash_command_registry.register_command(file_command::FileSlashCommand, true);
374    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
375    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
376    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
377    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
378    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
379    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
380    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
381    slash_command_registry.register_command(term_command::TermSlashCommand, true);
382    slash_command_registry.register_command(now_command::NowSlashCommand, true);
383    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
384    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
385    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
386}
387
/// Formats a token count compactly for display.
///
/// - Counts below 1000 are printed verbatim (e.g. `"999"`).
/// - Counts from 1000 to 9999 are rounded to the nearest hundred and shown in
///   thousands with at most one decimal digit (e.g. `"1k"`, `"1.1k"`, `"10k"`).
/// - Larger counts are rounded to the nearest thousand (e.g. `"11k"`, `"1000k"`).
///
/// Rounding is half-up. The function never overflows, even for counts close to
/// `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            // Round half-up to tenths of a thousand; `count` is at most 9999 here,
            // so the addition cannot overflow.
            let tenths = (count + 50) / 100;
            let (whole, fraction) = (tenths / 10, tenths % 10);
            if fraction == 0 {
                format!("{}k", whole)
            } else {
                format!("{}.{}k", whole, fraction)
            }
        }
        _ => {
            // Round half-up without computing `count + 500`, which would overflow
            // for counts within 500 of `usize::MAX`.
            let round_up = usize::from(count % 1000 >= 500);
            format!("{}k", count / 1000 + round_up)
        }
    }
}
405
/// Test-only constructor hook: initializes `env_logger` when the `RUST_LOG`
/// environment variable can be read, and does nothing otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}