assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod streaming_diff;
  8
  9mod embedded_scope;
 10
 11pub use assistant_panel::AssistantPanel;
 12use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
 13use chrono::{DateTime, Local};
 14use client::{proto, Client};
 15use command_palette_hooks::CommandPaletteFilter;
 16pub(crate) use completion_provider::*;
 17use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
 18pub(crate) use saved_conversation::*;
 19use serde::{Deserialize, Serialize};
 20use settings::{Settings, SettingsStore};
 21use std::{
 22    fmt::{self, Display},
 23    sync::Arc,
 24};
 25
// Actions defined by this crate. The first macro argument (`assistant`) is the
// action namespace; it is the same string as `Assistant::NAMESPACE`, which
// `init`/`Assistant::set_enabled` hide or show in the command palette.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        ToggleIncludeConversation,
    ]
);
 40
 41#[derive(
 42    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 43)]
 44struct MessageId(usize);
 45
 46#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 47#[serde(rename_all = "lowercase")]
 48pub enum Role {
 49    User,
 50    Assistant,
 51    System,
 52}
 53
 54impl Role {
 55    pub fn cycle(&mut self) {
 56        *self = match self {
 57            Role::User => Role::Assistant,
 58            Role::Assistant => Role::System,
 59            Role::System => Role::User,
 60        }
 61    }
 62}
 63
 64impl Display for Role {
 65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 66        match self {
 67            Role::User => write!(f, "user"),
 68            Role::Assistant => write!(f, "assistant"),
 69            Role::System => write!(f, "system"),
 70        }
 71    }
 72}
 73
 74#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 75pub enum LanguageModel {
 76    ZedDotDev(ZedDotDevModel),
 77    OpenAi(OpenAiModel),
 78}
 79
 80impl Default for LanguageModel {
 81    fn default() -> Self {
 82        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 83    }
 84}
 85
 86impl LanguageModel {
 87    pub fn telemetry_id(&self) -> String {
 88        match self {
 89            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 90            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 91        }
 92    }
 93
 94    pub fn display_name(&self) -> String {
 95        match self {
 96            LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
 97            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
 98        }
 99    }
100
101    pub fn max_token_count(&self) -> usize {
102        match self {
103            LanguageModel::OpenAi(model) => model.max_token_count(),
104            LanguageModel::ZedDotDev(model) => model.max_token_count(),
105        }
106    }
107
108    pub fn id(&self) -> &str {
109        match self {
110            LanguageModel::OpenAi(model) => model.id(),
111            LanguageModel::ZedDotDev(model) => model.id(),
112        }
113    }
114}
115
116#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
117pub struct LanguageModelRequestMessage {
118    pub role: Role,
119    pub content: String,
120}
121
122impl LanguageModelRequestMessage {
123    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
124        proto::LanguageModelRequestMessage {
125            role: match self.role {
126                Role::User => proto::LanguageModelRole::LanguageModelUser,
127                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
128                Role::System => proto::LanguageModelRole::LanguageModelSystem,
129            } as i32,
130            content: self.content.clone(),
131        }
132    }
133}
134
135#[derive(Debug, Default, Serialize)]
136pub struct LanguageModelRequest {
137    pub model: LanguageModel,
138    pub messages: Vec<LanguageModelRequestMessage>,
139    pub stop: Vec<String>,
140    pub temperature: f32,
141}
142
143impl LanguageModelRequest {
144    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
145        proto::CompleteWithLanguageModel {
146            model: self.model.id().to_string(),
147            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
148            stop: self.stop.clone(),
149            temperature: self.temperature,
150        }
151    }
152}
153
154#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
155pub struct LanguageModelResponseMessage {
156    pub role: Option<Role>,
157    pub content: Option<String>,
158}
159
160#[derive(Deserialize, Debug)]
161pub struct LanguageModelUsage {
162    pub prompt_tokens: u32,
163    pub completion_tokens: u32,
164    pub total_tokens: u32,
165}
166
167#[derive(Deserialize, Debug)]
168pub struct LanguageModelChoiceDelta {
169    pub index: u32,
170    pub delta: LanguageModelResponseMessage,
171    pub finish_reason: Option<String>,
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize)]
175struct MessageMetadata {
176    role: Role,
177    sent_at: DateTime<Local>,
178    status: MessageStatus,
179}
180
181#[derive(Clone, Debug, Serialize, Deserialize)]
182enum MessageStatus {
183    Pending,
184    Done,
185    Error(SharedString),
186}
187
188/// The state pertaining to the Assistant.
189#[derive(Default)]
190struct Assistant {
191    /// Whether the Assistant is enabled.
192    enabled: bool,
193}
194
195impl Global for Assistant {}
196
197impl Assistant {
198    const NAMESPACE: &'static str = "assistant";
199
200    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
201        if self.enabled == enabled {
202            return;
203        }
204
205        self.enabled = enabled;
206
207        if !enabled {
208            CommandPaletteFilter::update_global(cx, |filter, _cx| {
209                filter.hide_namespace(Self::NAMESPACE);
210            });
211
212            return;
213        }
214
215        CommandPaletteFilter::update_global(cx, |filter, _cx| {
216            filter.show_namespace(Self::NAMESPACE);
217        });
218    }
219}
220
221pub fn init(client: Arc<Client>, cx: &mut AppContext) {
222    cx.set_global(Assistant::default());
223    AssistantSettings::register(cx);
224    completion_provider::init(client, cx);
225    assistant_panel::init(cx);
226
227    CommandPaletteFilter::update_global(cx, |filter, _cx| {
228        filter.hide_namespace(Assistant::NAMESPACE);
229    });
230    cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
231        let settings = AssistantSettings::get_global(cx);
232
233        assistant.set_enabled(settings.enabled, cx);
234    });
235    cx.observe_global::<SettingsStore>(|cx| {
236        cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
237            let settings = AssistantSettings::get_global(cx);
238
239            assistant.set_enabled(settings.enabled, cx);
240        });
241    })
242    .detach();
243}
244
/// Test builds only: installs `env_logger` at load time (via `ctor`), but only
/// when `RUST_LOG` is set to a valid Unicode value. Otherwise no logger is
/// installed.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let Ok(_) = std::env::var("RUST_LOG") else {
        return;
    };
    env_logger::init();
}