assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12
 13pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 14use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 15use assistant_slash_command::SlashCommandRegistry;
 16use client::{proto, Client};
 17use command_palette_hooks::CommandPaletteFilter;
 18pub(crate) use completion_provider::*;
 19pub(crate) use context_store::*;
 20use fs::Fs;
 21use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 22pub(crate) use inline_assistant::*;
 23pub(crate) use model_selector::*;
 24use rustdoc::RustdocStore;
 25use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 26use serde::{Deserialize, Serialize};
 27use settings::{Settings, SettingsStore};
 28use slash_command::{
 29    active_command, default_command, diagnostics_command, fetch_command, file_command, now_command,
 30    project_command, prompt_command, rustdoc_command, search_command, tabs_command, term_command,
 31};
 32use std::{
 33    fmt::{self, Display},
 34    sync::Arc,
 35};
 36pub(crate) use streaming_diff::*;
 37
// Declares the assistant's GPUI actions under the `assistant` namespace.
// `Assistant::set_enabled` shows/hides this namespace in the command palette.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
 55
/// Opaque identifier for a message within a conversation.
///
/// The inner `usize` is assigned elsewhere; the full derive set makes this
/// usable as a map key and sortable.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 60
/// The author of a chat message.
///
/// Serialized in lowercase ("user", "assistant", "system"), matching the
/// `Display` impl below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 68
 69impl Role {
 70    pub fn cycle(&mut self) {
 71        *self = match self {
 72            Role::User => Role::Assistant,
 73            Role::Assistant => Role::System,
 74            Role::System => Role::User,
 75        }
 76    }
 77}
 78
 79impl Display for Role {
 80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 81        match self {
 82            Role::User => write!(f, "user"),
 83            Role::Assistant => write!(f, "assistant"),
 84            Role::System => write!(f, "system"),
 85        }
 86    }
 87}
 88
/// A concrete language model from one of the supported providers.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model served through the zed.dev cloud service.
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    Ollama(OllamaModel),
}
 96
 97impl Default for LanguageModel {
 98    fn default() -> Self {
 99        LanguageModel::Cloud(CloudModel::default())
100    }
101}
102
103impl LanguageModel {
104    pub fn telemetry_id(&self) -> String {
105        match self {
106            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
107            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
108            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
109            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
110        }
111    }
112
113    pub fn display_name(&self) -> String {
114        match self {
115            LanguageModel::OpenAi(model) => model.display_name().into(),
116            LanguageModel::Anthropic(model) => model.display_name().into(),
117            LanguageModel::Cloud(model) => model.display_name().into(),
118            LanguageModel::Ollama(model) => model.display_name().into(),
119        }
120    }
121
122    pub fn max_token_count(&self) -> usize {
123        match self {
124            LanguageModel::OpenAi(model) => model.max_token_count(),
125            LanguageModel::Anthropic(model) => model.max_token_count(),
126            LanguageModel::Cloud(model) => model.max_token_count(),
127            LanguageModel::Ollama(model) => model.max_token_count(),
128        }
129    }
130
131    pub fn id(&self) -> &str {
132        match self {
133            LanguageModel::OpenAi(model) => model.id(),
134            LanguageModel::Anthropic(model) => model.id(),
135            LanguageModel::Cloud(model) => model.id(),
136            LanguageModel::Ollama(model) => model.id(),
137        }
138    }
139}
140
/// A single message in an outgoing completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}
146
147impl LanguageModelRequestMessage {
148    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
149        proto::LanguageModelRequestMessage {
150            role: match self.role {
151                Role::User => proto::LanguageModelRole::LanguageModelUser,
152                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
153                Role::System => proto::LanguageModelRole::LanguageModelSystem,
154            } as i32,
155            content: self.content.clone(),
156            tool_calls: Vec::new(),
157            tool_call_id: None,
158        }
159    }
160}
161
/// A completion request destined for a language model provider.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    /// Conversation history, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop sequences that terminate generation when emitted.
    pub stop: Vec<String>,
    /// Sampling temperature.
    pub temperature: f32,
}
169
170impl LanguageModelRequest {
171    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
172        proto::CompleteWithLanguageModel {
173            model: self.model.id().to_string(),
174            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
175            stop: self.stop.clone(),
176            temperature: self.temperature,
177            tool_choice: None,
178            tools: Vec::new(),
179        }
180    }
181
182    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
183    pub fn preprocess(&mut self) {
184        match &self.model {
185            LanguageModel::OpenAi(_) => {}
186            LanguageModel::Anthropic(_) => {}
187            LanguageModel::Ollama(_) => {}
188            LanguageModel::Cloud(model) => match model {
189                CloudModel::Claude3Opus
190                | CloudModel::Claude3Sonnet
191                | CloudModel::Claude3Haiku
192                | CloudModel::Claude3_5Sonnet => {
193                    preprocess_anthropic_request(self);
194                }
195                _ => {}
196            },
197        }
198    }
199}
200
/// A (possibly partial) message received from a language model.
///
/// Both fields are optional because streamed chunks may carry only one of
/// them (see `LanguageModelChoiceDelta`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
206
/// Token-usage accounting reported by a provider.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
213
/// One incremental choice chunk from a streamed completion response.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The partial message content carried by this chunk.
    pub delta: LanguageModelResponseMessage,
    /// Provider-supplied reason generation stopped, if this is the final chunk.
    pub finish_reason: Option<String>,
}
220
/// Per-message bookkeeping stored alongside conversation content.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}
226
/// Lifecycle state of a message whose content is produced asynchronously.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Content is still being generated.
    Pending,
    Done,
    /// Generation failed; carries the error text.
    Error(SharedString),
}
233
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    ///
    /// Mirrors `AssistantSettings::enabled`; kept in sync by the settings
    /// observer installed in `init` via `set_enabled`.
    enabled: bool,
}
240
// Marker impl: allows `Assistant` to live in the GPUI global store
// (accessed via `set_global` / `update_global`).
impl Global for Assistant {}
242
243impl Assistant {
244    const NAMESPACE: &'static str = "assistant";
245
246    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
247        if self.enabled == enabled {
248            return;
249        }
250
251        self.enabled = enabled;
252
253        if !enabled {
254            CommandPaletteFilter::update_global(cx, |filter, _cx| {
255                filter.hide_namespace(Self::NAMESPACE);
256            });
257
258            return;
259        }
260
261        CommandPaletteFilter::update_global(cx, |filter, _cx| {
262            filter.show_namespace(Self::NAMESPACE);
263        });
264    }
265}
266
/// Initializes the assistant subsystem: globals, the semantic index,
/// slash commands, panels, and settings-driven enablement.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index asynchronously and store it as a global once
    // ready. The task is detached, so a failure here is silently dropped.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    RustdocStore::init_global(cx);

    // Start with the namespace hidden; `set_enabled` below reveals it if the
    // settings say the assistant is enabled.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply enablement whenever settings change.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
310
311fn register_slash_commands(cx: &mut AppContext) {
312    let slash_command_registry = SlashCommandRegistry::global(cx);
313    slash_command_registry.register_command(file_command::FileSlashCommand, true);
314    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
315    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
316    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
317    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
318    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
319    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
320    slash_command_registry.register_command(term_command::TermSlashCommand, true);
321    slash_command_registry.register_command(now_command::NowSlashCommand, true);
322    slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
323    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
324    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
325}
326
/// Formats a token count compactly for display:
/// below 1000 the exact number (e.g. "532"); from 1000 to 9999 one decimal
/// place with rounding to the nearest hundred (e.g. 1536 -> "1.5k",
/// 1950 -> "2k"); above that, rounded to the nearest thousand ("23k").
pub fn humanize_token_count(count: usize) -> String {
    if count < 1000 {
        return count.to_string();
    }
    if count < 10_000 {
        let whole = count / 1000;
        // Round the sub-thousand remainder to the nearest hundred (0..=10).
        let tenths = (count % 1000 + 50) / 100;
        return match tenths {
            0 => format!("{}k", whole),
            // Rounding carried into the next thousand (e.g. 1950 -> "2k").
            10 => format!("{}k", whole + 1),
            _ => format!("{}.{}k", whole, tenths),
        };
    }
    format!("{}k", (count + 500) / 1000)
}
344
// Test-only: runs once at binary startup (via the `ctor` crate) so that
// logging output is available in tests when RUST_LOG is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}