assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod model_selector;
  6mod prompts;
  7mod saved_conversation;
  8mod search;
  9mod slash_command;
 10mod streaming_diff;
 11
 12pub use assistant_panel::AssistantPanel;
 13
 14use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 15use client::{proto, Client};
 16use command_palette_hooks::CommandPaletteFilter;
 17pub(crate) use completion_provider::*;
 18use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 19pub(crate) use model_selector::*;
 20pub(crate) use saved_conversation::*;
 21use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 22use serde::{Deserialize, Serialize};
 23use settings::{Settings, SettingsStore};
 24use std::{
 25    fmt::{self, Display},
 26    sync::Arc,
 27};
 28use util::paths::EMBEDDINGS_DIR;
 29
 30actions!(
 31    assistant,
 32    [
 33        Assist,
 34        Split,
 35        CycleMessageRole,
 36        QuoteSelection,
 37        ToggleFocus,
 38        ResetKey,
 39        InlineAssist,
 40        InsertActivePrompt,
 41        ToggleHistory,
 42        ApplyEdit,
 43        ConfirmCommand,
 44        ToggleModelSelector
 45    ]
 46);
 47
 48#[derive(
 49    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 50)]
 51struct MessageId(usize);
 52
 53#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 54#[serde(rename_all = "lowercase")]
 55pub enum Role {
 56    User,
 57    Assistant,
 58    System,
 59}
 60
 61impl Role {
 62    pub fn cycle(&mut self) {
 63        *self = match self {
 64            Role::User => Role::Assistant,
 65            Role::Assistant => Role::System,
 66            Role::System => Role::User,
 67        }
 68    }
 69}
 70
 71impl Display for Role {
 72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 73        match self {
 74            Role::User => write!(f, "user"),
 75            Role::Assistant => write!(f, "assistant"),
 76            Role::System => write!(f, "system"),
 77        }
 78    }
 79}
 80
 81#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 82pub enum LanguageModel {
 83    ZedDotDev(ZedDotDevModel),
 84    OpenAi(OpenAiModel),
 85    Anthropic(AnthropicModel),
 86}
 87
 88impl Default for LanguageModel {
 89    fn default() -> Self {
 90        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 91    }
 92}
 93
 94impl LanguageModel {
 95    pub fn telemetry_id(&self) -> String {
 96        match self {
 97            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 98            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 99            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
100        }
101    }
102
103    pub fn display_name(&self) -> String {
104        match self {
105            LanguageModel::OpenAi(model) => model.display_name().into(),
106            LanguageModel::Anthropic(model) => model.display_name().into(),
107            LanguageModel::ZedDotDev(model) => model.display_name().into(),
108        }
109    }
110
111    pub fn max_token_count(&self) -> usize {
112        match self {
113            LanguageModel::OpenAi(model) => model.max_token_count(),
114            LanguageModel::Anthropic(model) => model.max_token_count(),
115            LanguageModel::ZedDotDev(model) => model.max_token_count(),
116        }
117    }
118
119    pub fn id(&self) -> &str {
120        match self {
121            LanguageModel::OpenAi(model) => model.id(),
122            LanguageModel::Anthropic(model) => model.id(),
123            LanguageModel::ZedDotDev(model) => model.id(),
124        }
125    }
126}
127
128#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
129pub struct LanguageModelRequestMessage {
130    pub role: Role,
131    pub content: String,
132}
133
134impl LanguageModelRequestMessage {
135    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
136        proto::LanguageModelRequestMessage {
137            role: match self.role {
138                Role::User => proto::LanguageModelRole::LanguageModelUser,
139                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
140                Role::System => proto::LanguageModelRole::LanguageModelSystem,
141            } as i32,
142            content: self.content.clone(),
143            tool_calls: Vec::new(),
144            tool_call_id: None,
145        }
146    }
147}
148
149#[derive(Debug, Default, Serialize)]
150pub struct LanguageModelRequest {
151    pub model: LanguageModel,
152    pub messages: Vec<LanguageModelRequestMessage>,
153    pub stop: Vec<String>,
154    pub temperature: f32,
155}
156
157impl LanguageModelRequest {
158    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
159        proto::CompleteWithLanguageModel {
160            model: self.model.id().to_string(),
161            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
162            stop: self.stop.clone(),
163            temperature: self.temperature,
164            tool_choice: None,
165            tools: Vec::new(),
166        }
167    }
168}
169
170#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
171pub struct LanguageModelResponseMessage {
172    pub role: Option<Role>,
173    pub content: Option<String>,
174}
175
176#[derive(Deserialize, Debug)]
177pub struct LanguageModelUsage {
178    pub prompt_tokens: u32,
179    pub completion_tokens: u32,
180    pub total_tokens: u32,
181}
182
183#[derive(Deserialize, Debug)]
184pub struct LanguageModelChoiceDelta {
185    pub index: u32,
186    pub delta: LanguageModelResponseMessage,
187    pub finish_reason: Option<String>,
188}
189
190#[derive(Clone, Debug, Serialize, Deserialize)]
191struct MessageMetadata {
192    role: Role,
193    status: MessageStatus,
194}
195
196#[derive(Clone, Debug, Serialize, Deserialize)]
197enum MessageStatus {
198    Pending,
199    Done,
200    Error(SharedString),
201}
202
203/// The state pertaining to the Assistant.
204#[derive(Default)]
205struct Assistant {
206    /// Whether the Assistant is enabled.
207    enabled: bool,
208}
209
210impl Global for Assistant {}
211
212impl Assistant {
213    const NAMESPACE: &'static str = "assistant";
214
215    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
216        if self.enabled == enabled {
217            return;
218        }
219
220        self.enabled = enabled;
221
222        if !enabled {
223            CommandPaletteFilter::update_global(cx, |filter, _cx| {
224                filter.hide_namespace(Self::NAMESPACE);
225            });
226
227            return;
228        }
229
230        CommandPaletteFilter::update_global(cx, |filter, _cx| {
231            filter.show_namespace(Self::NAMESPACE);
232        });
233    }
234}
235
236pub fn init(client: Arc<Client>, cx: &mut AppContext) {
237    cx.set_global(Assistant::default());
238    AssistantSettings::register(cx);
239
240    cx.spawn(|mut cx| {
241        let client = client.clone();
242        async move {
243            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
244            let semantic_index = SemanticIndex::new(
245                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
246                Arc::new(embedding_provider),
247                &mut cx,
248            )
249            .await?;
250            cx.update(|cx| cx.set_global(semantic_index))
251        }
252    })
253    .detach();
254    completion_provider::init(client, cx);
255    assistant_slash_command::init(cx);
256    assistant_panel::init(cx);
257
258    CommandPaletteFilter::update_global(cx, |filter, _cx| {
259        filter.hide_namespace(Assistant::NAMESPACE);
260    });
261    Assistant::update_global(cx, |assistant, cx| {
262        let settings = AssistantSettings::get_global(cx);
263
264        assistant.set_enabled(settings.enabled, cx);
265    });
266    cx.observe_global::<SettingsStore>(|cx| {
267        Assistant::update_global(cx, |assistant, cx| {
268            let settings = AssistantSettings::get_global(cx);
269
270            assistant.set_enabled(settings.enabled, cx);
271        });
272    })
273    .detach();
274}
275
/// Test-only logger setup that runs before tests start.
///
/// `env_logger` is initialized only when `RUST_LOG` is set (and readable as
/// Unicode); otherwise logging stays off.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}