assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod search;
  8mod slash_command;
  9mod streaming_diff;
 10
 11pub use assistant_panel::AssistantPanel;
 12
 13use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 14use client::{proto, Client};
 15use command_palette_hooks::CommandPaletteFilter;
 16pub(crate) use completion_provider::*;
 17use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 18pub(crate) use saved_conversation::*;
 19use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 20use serde::{Deserialize, Serialize};
 21use settings::{Settings, SettingsStore};
 22use std::{
 23    fmt::{self, Display},
 24    sync::Arc,
 25};
 26use util::paths::EMBEDDINGS_DIR;
 27
 28actions!(
 29    assistant,
 30    [
 31        Assist,
 32        Split,
 33        CycleMessageRole,
 34        QuoteSelection,
 35        ToggleFocus,
 36        ResetKey,
 37        InlineAssist,
 38        InsertActivePrompt,
 39        ToggleHistory,
 40        ApplyEdit,
 41        ConfirmCommand
 42    ]
 43);
 44
 45#[derive(
 46    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 47)]
 48struct MessageId(usize);
 49
 50#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 51#[serde(rename_all = "lowercase")]
 52pub enum Role {
 53    User,
 54    Assistant,
 55    System,
 56}
 57
 58impl Role {
 59    pub fn cycle(&mut self) {
 60        *self = match self {
 61            Role::User => Role::Assistant,
 62            Role::Assistant => Role::System,
 63            Role::System => Role::User,
 64        }
 65    }
 66}
 67
 68impl Display for Role {
 69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 70        match self {
 71            Role::User => write!(f, "user"),
 72            Role::Assistant => write!(f, "assistant"),
 73            Role::System => write!(f, "system"),
 74        }
 75    }
 76}
 77
 78#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 79pub enum LanguageModel {
 80    ZedDotDev(ZedDotDevModel),
 81    OpenAi(OpenAiModel),
 82    Anthropic(AnthropicModel),
 83}
 84
 85impl Default for LanguageModel {
 86    fn default() -> Self {
 87        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 88    }
 89}
 90
 91impl LanguageModel {
 92    pub fn telemetry_id(&self) -> String {
 93        match self {
 94            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 95            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 96            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 97        }
 98    }
 99
100    pub fn display_name(&self) -> String {
101        match self {
102            LanguageModel::OpenAi(model) => model.display_name().into(),
103            LanguageModel::Anthropic(model) => model.display_name().into(),
104            LanguageModel::ZedDotDev(model) => model.display_name().into(),
105        }
106    }
107
108    pub fn max_token_count(&self) -> usize {
109        match self {
110            LanguageModel::OpenAi(model) => model.max_token_count(),
111            LanguageModel::Anthropic(model) => model.max_token_count(),
112            LanguageModel::ZedDotDev(model) => model.max_token_count(),
113        }
114    }
115
116    pub fn id(&self) -> &str {
117        match self {
118            LanguageModel::OpenAi(model) => model.id(),
119            LanguageModel::Anthropic(model) => model.id(),
120            LanguageModel::ZedDotDev(model) => model.id(),
121        }
122    }
123}
124
125#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
126pub struct LanguageModelRequestMessage {
127    pub role: Role,
128    pub content: String,
129}
130
131impl LanguageModelRequestMessage {
132    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
133        proto::LanguageModelRequestMessage {
134            role: match self.role {
135                Role::User => proto::LanguageModelRole::LanguageModelUser,
136                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
137                Role::System => proto::LanguageModelRole::LanguageModelSystem,
138            } as i32,
139            content: self.content.clone(),
140            tool_calls: Vec::new(),
141            tool_call_id: None,
142        }
143    }
144}
145
146#[derive(Debug, Default, Serialize)]
147pub struct LanguageModelRequest {
148    pub model: LanguageModel,
149    pub messages: Vec<LanguageModelRequestMessage>,
150    pub stop: Vec<String>,
151    pub temperature: f32,
152}
153
154impl LanguageModelRequest {
155    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
156        proto::CompleteWithLanguageModel {
157            model: self.model.id().to_string(),
158            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
159            stop: self.stop.clone(),
160            temperature: self.temperature,
161            tool_choice: None,
162            tools: Vec::new(),
163        }
164    }
165}
166
167#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
168pub struct LanguageModelResponseMessage {
169    pub role: Option<Role>,
170    pub content: Option<String>,
171}
172
173#[derive(Deserialize, Debug)]
174pub struct LanguageModelUsage {
175    pub prompt_tokens: u32,
176    pub completion_tokens: u32,
177    pub total_tokens: u32,
178}
179
180#[derive(Deserialize, Debug)]
181pub struct LanguageModelChoiceDelta {
182    pub index: u32,
183    pub delta: LanguageModelResponseMessage,
184    pub finish_reason: Option<String>,
185}
186
187#[derive(Clone, Debug, Serialize, Deserialize)]
188struct MessageMetadata {
189    role: Role,
190    status: MessageStatus,
191}
192
193#[derive(Clone, Debug, Serialize, Deserialize)]
194enum MessageStatus {
195    Pending,
196    Done,
197    Error(SharedString),
198}
199
200/// The state pertaining to the Assistant.
201#[derive(Default)]
202struct Assistant {
203    /// Whether the Assistant is enabled.
204    enabled: bool,
205}
206
207impl Global for Assistant {}
208
209impl Assistant {
210    const NAMESPACE: &'static str = "assistant";
211
212    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
213        if self.enabled == enabled {
214            return;
215        }
216
217        self.enabled = enabled;
218
219        if !enabled {
220            CommandPaletteFilter::update_global(cx, |filter, _cx| {
221                filter.hide_namespace(Self::NAMESPACE);
222            });
223
224            return;
225        }
226
227        CommandPaletteFilter::update_global(cx, |filter, _cx| {
228            filter.show_namespace(Self::NAMESPACE);
229        });
230    }
231}
232
233pub fn init(client: Arc<Client>, cx: &mut AppContext) {
234    cx.set_global(Assistant::default());
235    AssistantSettings::register(cx);
236
237    cx.spawn(|mut cx| {
238        let client = client.clone();
239        async move {
240            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
241            let semantic_index = SemanticIndex::new(
242                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
243                Arc::new(embedding_provider),
244                &mut cx,
245            )
246            .await?;
247            cx.update(|cx| cx.set_global(semantic_index))
248        }
249    })
250    .detach();
251    completion_provider::init(client, cx);
252    assistant_slash_command::init(cx);
253    assistant_panel::init(cx);
254
255    CommandPaletteFilter::update_global(cx, |filter, _cx| {
256        filter.hide_namespace(Assistant::NAMESPACE);
257    });
258    Assistant::update_global(cx, |assistant, cx| {
259        let settings = AssistantSettings::get_global(cx);
260
261        assistant.set_enabled(settings.enabled, cx);
262    });
263    cx.observe_global::<SettingsStore>(|cx| {
264        Assistant::update_global(cx, |assistant, cx| {
265            let settings = AssistantSettings::get_global(cx);
266
267            assistant.set_enabled(settings.enabled, cx);
268        });
269    })
270    .detach();
271}
272
/// Test-build constructor (`ctor::ctor`) that initializes `env_logger`, but only when
/// `RUST_LOG` is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}
278    }
279}