assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod streaming_diff;
  8
  9mod embedded_scope;
 10
 11pub use assistant_panel::AssistantPanel;
 12use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
 13use chrono::{DateTime, Local};
 14use client::{proto, Client};
 15use command_palette_hooks::CommandPaletteFilter;
 16pub(crate) use completion_provider::*;
 17use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
 18pub(crate) use saved_conversation::*;
 19use serde::{Deserialize, Serialize};
 20use settings::{Settings, SettingsStore};
 21use std::{
 22    fmt::{self, Display},
 23    sync::Arc,
 24};
 25
 26actions!(
 27    assistant,
 28    [
 29        NewConversation,
 30        Assist,
 31        Split,
 32        CycleMessageRole,
 33        QuoteSelection,
 34        ToggleFocus,
 35        ResetKey,
 36        InlineAssist,
 37        ToggleIncludeConversation,
 38    ]
 39);
 40
 41#[derive(
 42    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 43)]
 44struct MessageId(usize);
 45
 46#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 47#[serde(rename_all = "lowercase")]
 48pub enum Role {
 49    User,
 50    Assistant,
 51    System,
 52}
 53
 54impl Role {
 55    pub fn cycle(&mut self) {
 56        *self = match self {
 57            Role::User => Role::Assistant,
 58            Role::Assistant => Role::System,
 59            Role::System => Role::User,
 60        }
 61    }
 62}
 63
 64impl Display for Role {
 65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 66        match self {
 67            Role::User => write!(f, "user"),
 68            Role::Assistant => write!(f, "assistant"),
 69            Role::System => write!(f, "system"),
 70        }
 71    }
 72}
 73
 74#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 75pub enum LanguageModel {
 76    ZedDotDev(ZedDotDevModel),
 77    OpenAi(OpenAiModel),
 78}
 79
 80impl Default for LanguageModel {
 81    fn default() -> Self {
 82        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 83    }
 84}
 85
 86impl LanguageModel {
 87    pub fn telemetry_id(&self) -> String {
 88        match self {
 89            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 90            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 91        }
 92    }
 93
 94    pub fn display_name(&self) -> String {
 95        match self {
 96            LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
 97            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
 98        }
 99    }
100
101    pub fn max_token_count(&self) -> usize {
102        match self {
103            LanguageModel::OpenAi(model) => model.max_token_count(),
104            LanguageModel::ZedDotDev(model) => model.max_token_count(),
105        }
106    }
107
108    pub fn id(&self) -> &str {
109        match self {
110            LanguageModel::OpenAi(model) => model.id(),
111            LanguageModel::ZedDotDev(model) => model.id(),
112        }
113    }
114}
115
116#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
117pub struct LanguageModelRequestMessage {
118    pub role: Role,
119    pub content: String,
120}
121
122impl LanguageModelRequestMessage {
123    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
124        proto::LanguageModelRequestMessage {
125            role: match self.role {
126                Role::User => proto::LanguageModelRole::LanguageModelUser,
127                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
128                Role::System => proto::LanguageModelRole::LanguageModelSystem,
129            } as i32,
130            content: self.content.clone(),
131            tool_calls: Vec::new(),
132            tool_call_id: None,
133        }
134    }
135}
136
137#[derive(Debug, Default, Serialize)]
138pub struct LanguageModelRequest {
139    pub model: LanguageModel,
140    pub messages: Vec<LanguageModelRequestMessage>,
141    pub stop: Vec<String>,
142    pub temperature: f32,
143}
144
145impl LanguageModelRequest {
146    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
147        proto::CompleteWithLanguageModel {
148            model: self.model.id().to_string(),
149            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
150            stop: self.stop.clone(),
151            temperature: self.temperature,
152            tool_choice: None,
153            tools: Vec::new(),
154        }
155    }
156}
157
158#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
159pub struct LanguageModelResponseMessage {
160    pub role: Option<Role>,
161    pub content: Option<String>,
162}
163
164#[derive(Deserialize, Debug)]
165pub struct LanguageModelUsage {
166    pub prompt_tokens: u32,
167    pub completion_tokens: u32,
168    pub total_tokens: u32,
169}
170
171#[derive(Deserialize, Debug)]
172pub struct LanguageModelChoiceDelta {
173    pub index: u32,
174    pub delta: LanguageModelResponseMessage,
175    pub finish_reason: Option<String>,
176}
177
178#[derive(Clone, Debug, Serialize, Deserialize)]
179struct MessageMetadata {
180    role: Role,
181    sent_at: DateTime<Local>,
182    status: MessageStatus,
183}
184
185#[derive(Clone, Debug, Serialize, Deserialize)]
186enum MessageStatus {
187    Pending,
188    Done,
189    Error(SharedString),
190}
191
192/// The state pertaining to the Assistant.
193#[derive(Default)]
194struct Assistant {
195    /// Whether the Assistant is enabled.
196    enabled: bool,
197}
198
199impl Global for Assistant {}
200
201impl Assistant {
202    const NAMESPACE: &'static str = "assistant";
203
204    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
205        if self.enabled == enabled {
206            return;
207        }
208
209        self.enabled = enabled;
210
211        if !enabled {
212            CommandPaletteFilter::update_global(cx, |filter, _cx| {
213                filter.hide_namespace(Self::NAMESPACE);
214            });
215
216            return;
217        }
218
219        CommandPaletteFilter::update_global(cx, |filter, _cx| {
220            filter.show_namespace(Self::NAMESPACE);
221        });
222    }
223}
224
225pub fn init(client: Arc<Client>, cx: &mut AppContext) {
226    cx.set_global(Assistant::default());
227    AssistantSettings::register(cx);
228    completion_provider::init(client, cx);
229    assistant_panel::init(cx);
230
231    CommandPaletteFilter::update_global(cx, |filter, _cx| {
232        filter.hide_namespace(Assistant::NAMESPACE);
233    });
234    cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
235        let settings = AssistantSettings::get_global(cx);
236
237        assistant.set_enabled(settings.enabled, cx);
238    });
239    cx.observe_global::<SettingsStore>(|cx| {
240        cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
241            let settings = AssistantSettings::get_global(cx);
242
243            assistant.set_enabled(settings.enabled, cx);
244        });
245    })
246    .detach();
247}
248
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Test builds install a logger only when `RUST_LOG` is set to a valid
    // Unicode value; otherwise they stay silent.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}