assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context;
  5pub mod context_store;
  6mod inline_assistant;
  7mod model_selector;
  8mod prompt_library;
  9mod prompts;
 10mod search;
 11mod slash_command;
 12mod streaming_diff;
 13mod terminal_inline_assistant;
 14
 15pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 16use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 17use assistant_slash_command::SlashCommandRegistry;
 18use client::{proto, Client};
 19use command_palette_hooks::CommandPaletteFilter;
 20pub use completion_provider::*;
 21pub use context::*;
 22pub use context_store::*;
 23use fs::Fs;
 24use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 25use indexed_docs::IndexedDocsRegistry;
 26pub(crate) use inline_assistant::*;
 27pub(crate) use model_selector::*;
 28use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings, SettingsStore};
 31use slash_command::{
 32    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 33    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 34    tabs_command, term_command,
 35};
 36use std::{
 37    fmt::{self, Display},
 38    sync::Arc,
 39};
 40pub(crate) use streaming_diff::*;
 41
// Actions registered under the `assistant` namespace. `Assistant::set_enabled` shows or hides
// this namespace in the command palette to match the `enabled` assistant setting.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
 61
 62#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 63pub struct MessageId(clock::Lamport);
 64
 65impl MessageId {
 66    pub fn as_u64(self) -> u64 {
 67        self.0.as_u64()
 68    }
 69}
 70
 71#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 72#[serde(rename_all = "lowercase")]
 73pub enum Role {
 74    User,
 75    Assistant,
 76    System,
 77}
 78
 79impl Role {
 80    pub fn from_proto(role: i32) -> Role {
 81        match proto::LanguageModelRole::from_i32(role) {
 82            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
 83            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
 84            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
 85            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
 86            None => Role::User,
 87        }
 88    }
 89
 90    pub fn to_proto(&self) -> proto::LanguageModelRole {
 91        match self {
 92            Role::User => proto::LanguageModelRole::LanguageModelUser,
 93            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
 94            Role::System => proto::LanguageModelRole::LanguageModelSystem,
 95        }
 96    }
 97
 98    pub fn cycle(self) -> Role {
 99        match self {
100            Role::User => Role::Assistant,
101            Role::Assistant => Role::System,
102            Role::System => Role::User,
103        }
104    }
105}
106
107impl Display for Role {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
109        match self {
110            Role::User => write!(f, "user"),
111            Role::Assistant => write!(f, "assistant"),
112            Role::System => write!(f, "system"),
113        }
114    }
115}
116
117#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
118pub enum LanguageModel {
119    Cloud(CloudModel),
120    OpenAi(OpenAiModel),
121    Anthropic(AnthropicModel),
122    Ollama(OllamaModel),
123}
124
125impl Default for LanguageModel {
126    fn default() -> Self {
127        LanguageModel::Cloud(CloudModel::default())
128    }
129}
130
131impl LanguageModel {
132    pub fn telemetry_id(&self) -> String {
133        match self {
134            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
135            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
136            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
137            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
138        }
139    }
140
141    pub fn display_name(&self) -> String {
142        match self {
143            LanguageModel::OpenAi(model) => model.display_name().into(),
144            LanguageModel::Anthropic(model) => model.display_name().into(),
145            LanguageModel::Cloud(model) => model.display_name().into(),
146            LanguageModel::Ollama(model) => model.display_name().into(),
147        }
148    }
149
150    pub fn max_token_count(&self) -> usize {
151        match self {
152            LanguageModel::OpenAi(model) => model.max_token_count(),
153            LanguageModel::Anthropic(model) => model.max_token_count(),
154            LanguageModel::Cloud(model) => model.max_token_count(),
155            LanguageModel::Ollama(model) => model.max_token_count(),
156        }
157    }
158
159    pub fn id(&self) -> &str {
160        match self {
161            LanguageModel::OpenAi(model) => model.id(),
162            LanguageModel::Anthropic(model) => model.id(),
163            LanguageModel::Cloud(model) => model.id(),
164            LanguageModel::Ollama(model) => model.id(),
165        }
166    }
167}
168
169#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
170pub struct LanguageModelRequestMessage {
171    pub role: Role,
172    pub content: String,
173}
174
175impl LanguageModelRequestMessage {
176    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
177        proto::LanguageModelRequestMessage {
178            role: self.role.to_proto() as i32,
179            content: self.content.clone(),
180            tool_calls: Vec::new(),
181            tool_call_id: None,
182        }
183    }
184}
185
186#[derive(Debug, Default, Serialize, Deserialize)]
187pub struct LanguageModelRequest {
188    pub model: LanguageModel,
189    pub messages: Vec<LanguageModelRequestMessage>,
190    pub stop: Vec<String>,
191    pub temperature: f32,
192}
193
194impl LanguageModelRequest {
195    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
196        proto::CompleteWithLanguageModel {
197            model: self.model.id().to_string(),
198            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
199            stop: self.stop.clone(),
200            temperature: self.temperature,
201            tool_choice: None,
202            tools: Vec::new(),
203        }
204    }
205
206    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
207    pub fn preprocess(&mut self) {
208        match &self.model {
209            LanguageModel::OpenAi(_) => {}
210            LanguageModel::Anthropic(_) => {}
211            LanguageModel::Ollama(_) => {}
212            LanguageModel::Cloud(model) => match model {
213                CloudModel::Claude3Opus
214                | CloudModel::Claude3Sonnet
215                | CloudModel::Claude3Haiku
216                | CloudModel::Claude3_5Sonnet => {
217                    preprocess_anthropic_request(self);
218                }
219                _ => {}
220            },
221        }
222    }
223}
224
225#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
226pub struct LanguageModelResponseMessage {
227    pub role: Option<Role>,
228    pub content: Option<String>,
229}
230
231#[derive(Deserialize, Debug)]
232pub struct LanguageModelUsage {
233    pub prompt_tokens: u32,
234    pub completion_tokens: u32,
235    pub total_tokens: u32,
236}
237
238#[derive(Deserialize, Debug)]
239pub struct LanguageModelChoiceDelta {
240    pub index: u32,
241    pub delta: LanguageModelResponseMessage,
242    pub finish_reason: Option<String>,
243}
244
245#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
246pub enum MessageStatus {
247    Pending,
248    Done,
249    Error(SharedString),
250}
251
252impl MessageStatus {
253    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
254        match status.variant {
255            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
256            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
257            Some(proto::context_message_status::Variant::Error(error)) => {
258                MessageStatus::Error(error.message.into())
259            }
260            None => MessageStatus::Pending,
261        }
262    }
263
264    pub fn to_proto(&self) -> proto::ContextMessageStatus {
265        match self {
266            MessageStatus::Pending => proto::ContextMessageStatus {
267                variant: Some(proto::context_message_status::Variant::Pending(
268                    proto::context_message_status::Pending {},
269                )),
270            },
271            MessageStatus::Done => proto::ContextMessageStatus {
272                variant: Some(proto::context_message_status::Variant::Done(
273                    proto::context_message_status::Done {},
274                )),
275            },
276            MessageStatus::Error(message) => proto::ContextMessageStatus {
277                variant: Some(proto::context_message_status::Variant::Error(
278                    proto::context_message_status::Error {
279                        message: message.to_string(),
280                    },
281                )),
282            },
283        }
284    }
285}
286
287/// The state pertaining to the Assistant.
288#[derive(Default)]
289struct Assistant {
290    /// Whether the Assistant is enabled.
291    enabled: bool,
292}
293
294impl Global for Assistant {}
295
296impl Assistant {
297    const NAMESPACE: &'static str = "assistant";
298
299    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
300        if self.enabled == enabled {
301            return;
302        }
303
304        self.enabled = enabled;
305
306        if !enabled {
307            CommandPaletteFilter::update_global(cx, |filter, _cx| {
308                filter.hide_namespace(Self::NAMESPACE);
309            });
310
311            return;
312        }
313
314        CommandPaletteFilter::update_global(cx, |filter, _cx| {
315            filter.show_namespace(Self::NAMESPACE);
316        });
317    }
318}
319
320pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
321    cx.set_global(Assistant::default());
322    AssistantSettings::register(cx);
323
324    cx.spawn(|mut cx| {
325        let client = client.clone();
326        async move {
327            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
328            let semantic_index = SemanticIndex::new(
329                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
330                Arc::new(embedding_provider),
331                &mut cx,
332            )
333            .await?;
334            cx.update(|cx| cx.set_global(semantic_index))
335        }
336    })
337    .detach();
338
339    context_store::init(&client);
340    prompt_library::init(cx);
341    completion_provider::init(client.clone(), cx);
342    assistant_slash_command::init(cx);
343    register_slash_commands(cx);
344    assistant_panel::init(cx);
345    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
346    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
347    IndexedDocsRegistry::init_global(cx);
348
349    CommandPaletteFilter::update_global(cx, |filter, _cx| {
350        filter.hide_namespace(Assistant::NAMESPACE);
351    });
352    Assistant::update_global(cx, |assistant, cx| {
353        let settings = AssistantSettings::get_global(cx);
354
355        assistant.set_enabled(settings.enabled, cx);
356    });
357    cx.observe_global::<SettingsStore>(|cx| {
358        Assistant::update_global(cx, |assistant, cx| {
359            let settings = AssistantSettings::get_global(cx);
360            assistant.set_enabled(settings.enabled, cx);
361        });
362    })
363    .detach();
364}
365
366fn register_slash_commands(cx: &mut AppContext) {
367    let slash_command_registry = SlashCommandRegistry::global(cx);
368    slash_command_registry.register_command(file_command::FileSlashCommand, true);
369    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
370    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
371    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
372    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
373    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
374    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
375    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
376    slash_command_registry.register_command(term_command::TermSlashCommand, true);
377    slash_command_registry.register_command(now_command::NowSlashCommand, true);
378    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
379    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
380    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
381}
382
/// Formats a token count as a short, human-readable string.
///
/// - Counts below 1,000 are shown verbatim (`999` -> `"999"`).
/// - Counts from 1,000 to 9,999 are shown in thousands with one decimal place. The decimal
///   is rounded half-up to the nearest hundred, and a zero decimal is dropped
///   (`1049` -> `"1k"`, `1050` -> `"1.1k"`, `9950` -> `"10k"`).
/// - Larger counts are shown in whole thousands, rounded half-up (`10_500` -> `"11k"`).
///
/// Never panics, including for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Round the remainder to the nearest hundred (half-up). This can carry to 10,
            // e.g. 1950 -> 2k, which is handled by the second arm below.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Equivalent to `(count + 500) / 1000`, but without the addition, which
            // overflows for counts within 500 of `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
400
/// Test-only constructor, run automatically at startup via `ctor::ctor`.
///
/// Initializes `env_logger`, but only when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}