assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12
 13pub use assistant_panel::AssistantPanel;
 14
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub(crate) use completion_provider::*;
 20pub(crate) use context_store::*;
 21use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 22pub(crate) use inline_assistant::*;
 23pub(crate) use model_selector::*;
 24use rustdoc::RustdocStore;
 25use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 26use serde::{Deserialize, Serialize};
 27use settings::{Settings, SettingsStore};
 28use slash_command::{
 29    active_command, default_command, fetch_command, file_command, now_command, project_command,
 30    prompt_command, rustdoc_command, search_command, tabs_command,
 31};
 32use std::{
 33    fmt::{self, Display},
 34    sync::Arc,
 35};
 36pub(crate) use streaming_diff::*;
 37use util::paths::EMBEDDINGS_DIR;
 38
// Actions declared in the `assistant` namespace. `Assistant::set_enabled`
// shows or hides this same namespace in the command palette filter.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
 56
 57#[derive(
 58    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 59)]
 60struct MessageId(usize);
 61
 62#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 63#[serde(rename_all = "lowercase")]
 64pub enum Role {
 65    User,
 66    Assistant,
 67    System,
 68}
 69
 70impl Role {
 71    pub fn cycle(&mut self) {
 72        *self = match self {
 73            Role::User => Role::Assistant,
 74            Role::Assistant => Role::System,
 75            Role::System => Role::User,
 76        }
 77    }
 78}
 79
 80impl Display for Role {
 81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 82        match self {
 83            Role::User => write!(f, "user"),
 84            Role::Assistant => write!(f, "assistant"),
 85            Role::System => write!(f, "system"),
 86        }
 87    }
 88}
 89
 90#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 91pub enum LanguageModel {
 92    Cloud(CloudModel),
 93    OpenAi(OpenAiModel),
 94    Anthropic(AnthropicModel),
 95    Ollama(OllamaModel),
 96}
 97
 98impl Default for LanguageModel {
 99    fn default() -> Self {
100        LanguageModel::Cloud(CloudModel::default())
101    }
102}
103
104impl LanguageModel {
105    pub fn telemetry_id(&self) -> String {
106        match self {
107            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
108            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
109            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
110            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
111        }
112    }
113
114    pub fn display_name(&self) -> String {
115        match self {
116            LanguageModel::OpenAi(model) => model.display_name().into(),
117            LanguageModel::Anthropic(model) => model.display_name().into(),
118            LanguageModel::Cloud(model) => model.display_name().into(),
119            LanguageModel::Ollama(model) => model.display_name().into(),
120        }
121    }
122
123    pub fn max_token_count(&self) -> usize {
124        match self {
125            LanguageModel::OpenAi(model) => model.max_token_count(),
126            LanguageModel::Anthropic(model) => model.max_token_count(),
127            LanguageModel::Cloud(model) => model.max_token_count(),
128            LanguageModel::Ollama(model) => model.max_token_count(),
129        }
130    }
131
132    pub fn id(&self) -> &str {
133        match self {
134            LanguageModel::OpenAi(model) => model.id(),
135            LanguageModel::Anthropic(model) => model.id(),
136            LanguageModel::Cloud(model) => model.id(),
137            LanguageModel::Ollama(model) => model.id(),
138        }
139    }
140}
141
142#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
143pub struct LanguageModelRequestMessage {
144    pub role: Role,
145    pub content: String,
146}
147
148impl LanguageModelRequestMessage {
149    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
150        proto::LanguageModelRequestMessage {
151            role: match self.role {
152                Role::User => proto::LanguageModelRole::LanguageModelUser,
153                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
154                Role::System => proto::LanguageModelRole::LanguageModelSystem,
155            } as i32,
156            content: self.content.clone(),
157            tool_calls: Vec::new(),
158            tool_call_id: None,
159        }
160    }
161}
162
163#[derive(Debug, Default, Serialize)]
164pub struct LanguageModelRequest {
165    pub model: LanguageModel,
166    pub messages: Vec<LanguageModelRequestMessage>,
167    pub stop: Vec<String>,
168    pub temperature: f32,
169}
170
171impl LanguageModelRequest {
172    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
173        proto::CompleteWithLanguageModel {
174            model: self.model.id().to_string(),
175            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
176            stop: self.stop.clone(),
177            temperature: self.temperature,
178            tool_choice: None,
179            tools: Vec::new(),
180        }
181    }
182
183    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
184    pub fn preprocess(&mut self) {
185        match &self.model {
186            LanguageModel::OpenAi(_) => {}
187            LanguageModel::Anthropic(_) => {}
188            LanguageModel::Ollama(_) => {}
189            LanguageModel::Cloud(model) => match model {
190                CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
191                    preprocess_anthropic_request(self);
192                }
193                _ => {}
194            },
195        }
196    }
197}
198
199#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
200pub struct LanguageModelResponseMessage {
201    pub role: Option<Role>,
202    pub content: Option<String>,
203}
204
205#[derive(Deserialize, Debug)]
206pub struct LanguageModelUsage {
207    pub prompt_tokens: u32,
208    pub completion_tokens: u32,
209    pub total_tokens: u32,
210}
211
212#[derive(Deserialize, Debug)]
213pub struct LanguageModelChoiceDelta {
214    pub index: u32,
215    pub delta: LanguageModelResponseMessage,
216    pub finish_reason: Option<String>,
217}
218
219#[derive(Clone, Debug, Serialize, Deserialize)]
220struct MessageMetadata {
221    role: Role,
222    status: MessageStatus,
223}
224
225#[derive(Clone, Debug, Serialize, Deserialize)]
226enum MessageStatus {
227    Pending,
228    Done,
229    Error(SharedString),
230}
231
232/// The state pertaining to the Assistant.
233#[derive(Default)]
234struct Assistant {
235    /// Whether the Assistant is enabled.
236    enabled: bool,
237}
238
239impl Global for Assistant {}
240
241impl Assistant {
242    const NAMESPACE: &'static str = "assistant";
243
244    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
245        if self.enabled == enabled {
246            return;
247        }
248
249        self.enabled = enabled;
250
251        if !enabled {
252            CommandPaletteFilter::update_global(cx, |filter, _cx| {
253                filter.hide_namespace(Self::NAMESPACE);
254            });
255
256            return;
257        }
258
259        CommandPaletteFilter::update_global(cx, |filter, _cx| {
260            filter.show_namespace(Self::NAMESPACE);
261        });
262    }
263}
264
265pub fn init(client: Arc<Client>, cx: &mut AppContext) {
266    cx.set_global(Assistant::default());
267    AssistantSettings::register(cx);
268
269    cx.spawn(|mut cx| {
270        let client = client.clone();
271        async move {
272            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
273            let semantic_index = SemanticIndex::new(
274                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
275                Arc::new(embedding_provider),
276                &mut cx,
277            )
278            .await?;
279            cx.update(|cx| cx.set_global(semantic_index))
280        }
281    })
282    .detach();
283
284    prompt_library::init(cx);
285    completion_provider::init(client.clone(), cx);
286    assistant_slash_command::init(cx);
287    register_slash_commands(cx);
288    assistant_panel::init(cx);
289    inline_assistant::init(client.telemetry().clone(), cx);
290    RustdocStore::init_global(cx);
291
292    CommandPaletteFilter::update_global(cx, |filter, _cx| {
293        filter.hide_namespace(Assistant::NAMESPACE);
294    });
295    Assistant::update_global(cx, |assistant, cx| {
296        let settings = AssistantSettings::get_global(cx);
297
298        assistant.set_enabled(settings.enabled, cx);
299    });
300    cx.observe_global::<SettingsStore>(|cx| {
301        Assistant::update_global(cx, |assistant, cx| {
302            let settings = AssistantSettings::get_global(cx);
303            assistant.set_enabled(settings.enabled, cx);
304        });
305    })
306    .detach();
307}
308
309fn register_slash_commands(cx: &mut AppContext) {
310    let slash_command_registry = SlashCommandRegistry::global(cx);
311    slash_command_registry.register_command(file_command::FileSlashCommand, true);
312    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
313    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
314    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
315    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
316    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
317    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
318    slash_command_registry.register_command(now_command::NowSlashCommand, true);
319    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
320    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
321}
322
/// Test-only startup hook (registered via `ctor`). It initializes `env_logger`
/// only when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}