assistant.rs

  1mod ambient_context;
  2pub mod assistant_panel;
  3pub mod assistant_settings;
  4mod codegen;
  5mod completion_provider;
  6mod prompts;
  7mod saved_conversation;
  8mod streaming_diff;
  9
 10pub use assistant_panel::AssistantPanel;
 11use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 12use client::{proto, Client};
 13use command_palette_hooks::CommandPaletteFilter;
 14pub(crate) use completion_provider::*;
 15use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 16pub(crate) use saved_conversation::*;
 17use serde::{Deserialize, Serialize};
 18use settings::{Settings, SettingsStore};
 19use std::{
 20    fmt::{self, Display},
 21    sync::Arc,
 22};
 23
 24actions!(
 25    assistant,
 26    [
 27        Assist,
 28        Split,
 29        CycleMessageRole,
 30        QuoteSelection,
 31        ToggleFocus,
 32        ResetKey,
 33        InlineAssist,
 34        ToggleIncludeConversation,
 35        ToggleHistory,
 36    ]
 37);
 38
 39#[derive(
 40    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 41)]
 42struct MessageId(usize);
 43
 44#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 45#[serde(rename_all = "lowercase")]
 46pub enum Role {
 47    User,
 48    Assistant,
 49    System,
 50}
 51
 52impl Role {
 53    pub fn cycle(&mut self) {
 54        *self = match self {
 55            Role::User => Role::Assistant,
 56            Role::Assistant => Role::System,
 57            Role::System => Role::User,
 58        }
 59    }
 60}
 61
 62impl Display for Role {
 63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 64        match self {
 65            Role::User => write!(f, "user"),
 66            Role::Assistant => write!(f, "assistant"),
 67            Role::System => write!(f, "system"),
 68        }
 69    }
 70}
 71
 72#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 73pub enum LanguageModel {
 74    ZedDotDev(ZedDotDevModel),
 75    OpenAi(OpenAiModel),
 76    Anthropic(AnthropicModel),
 77}
 78
 79impl Default for LanguageModel {
 80    fn default() -> Self {
 81        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 82    }
 83}
 84
 85impl LanguageModel {
 86    pub fn telemetry_id(&self) -> String {
 87        match self {
 88            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 89            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 90            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 91        }
 92    }
 93
 94    pub fn display_name(&self) -> String {
 95        match self {
 96            LanguageModel::OpenAi(model) => model.display_name().into(),
 97            LanguageModel::Anthropic(model) => model.display_name().into(),
 98            LanguageModel::ZedDotDev(model) => model.display_name().into(),
 99        }
100    }
101
102    pub fn max_token_count(&self) -> usize {
103        match self {
104            LanguageModel::OpenAi(model) => model.max_token_count(),
105            LanguageModel::Anthropic(model) => model.max_token_count(),
106            LanguageModel::ZedDotDev(model) => model.max_token_count(),
107        }
108    }
109
110    pub fn id(&self) -> &str {
111        match self {
112            LanguageModel::OpenAi(model) => model.id(),
113            LanguageModel::Anthropic(model) => model.id(),
114            LanguageModel::ZedDotDev(model) => model.id(),
115        }
116    }
117}
118
119#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
120pub struct LanguageModelRequestMessage {
121    pub role: Role,
122    pub content: String,
123}
124
125impl LanguageModelRequestMessage {
126    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
127        proto::LanguageModelRequestMessage {
128            role: match self.role {
129                Role::User => proto::LanguageModelRole::LanguageModelUser,
130                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
131                Role::System => proto::LanguageModelRole::LanguageModelSystem,
132            } as i32,
133            content: self.content.clone(),
134            tool_calls: Vec::new(),
135            tool_call_id: None,
136        }
137    }
138}
139
140#[derive(Debug, Default, Serialize)]
141pub struct LanguageModelRequest {
142    pub model: LanguageModel,
143    pub messages: Vec<LanguageModelRequestMessage>,
144    pub stop: Vec<String>,
145    pub temperature: f32,
146}
147
148impl LanguageModelRequest {
149    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
150        proto::CompleteWithLanguageModel {
151            model: self.model.id().to_string(),
152            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
153            stop: self.stop.clone(),
154            temperature: self.temperature,
155            tool_choice: None,
156            tools: Vec::new(),
157        }
158    }
159}
160
161#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
162pub struct LanguageModelResponseMessage {
163    pub role: Option<Role>,
164    pub content: Option<String>,
165}
166
167#[derive(Deserialize, Debug)]
168pub struct LanguageModelUsage {
169    pub prompt_tokens: u32,
170    pub completion_tokens: u32,
171    pub total_tokens: u32,
172}
173
174#[derive(Deserialize, Debug)]
175pub struct LanguageModelChoiceDelta {
176    pub index: u32,
177    pub delta: LanguageModelResponseMessage,
178    pub finish_reason: Option<String>,
179}
180
181#[derive(Clone, Debug, Serialize, Deserialize)]
182struct MessageMetadata {
183    role: Role,
184    status: MessageStatus,
185}
186
187#[derive(Clone, Debug, Serialize, Deserialize)]
188enum MessageStatus {
189    Pending,
190    Done,
191    Error(SharedString),
192}
193
194/// The state pertaining to the Assistant.
195#[derive(Default)]
196struct Assistant {
197    /// Whether the Assistant is enabled.
198    enabled: bool,
199}
200
201impl Global for Assistant {}
202
203impl Assistant {
204    const NAMESPACE: &'static str = "assistant";
205
206    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
207        if self.enabled == enabled {
208            return;
209        }
210
211        self.enabled = enabled;
212
213        if !enabled {
214            CommandPaletteFilter::update_global(cx, |filter, _cx| {
215                filter.hide_namespace(Self::NAMESPACE);
216            });
217
218            return;
219        }
220
221        CommandPaletteFilter::update_global(cx, |filter, _cx| {
222            filter.show_namespace(Self::NAMESPACE);
223        });
224    }
225}
226
227pub fn init(client: Arc<Client>, cx: &mut AppContext) {
228    cx.set_global(Assistant::default());
229    AssistantSettings::register(cx);
230    completion_provider::init(client, cx);
231    assistant_panel::init(cx);
232
233    CommandPaletteFilter::update_global(cx, |filter, _cx| {
234        filter.hide_namespace(Assistant::NAMESPACE);
235    });
236    Assistant::update_global(cx, |assistant, cx| {
237        let settings = AssistantSettings::get_global(cx);
238
239        assistant.set_enabled(settings.enabled, cx);
240    });
241    cx.observe_global::<SettingsStore>(|cx| {
242        Assistant::update_global(cx, |assistant, cx| {
243            let settings = AssistantSettings::get_global(cx);
244
245            assistant.set_enabled(settings.enabled, cx);
246        });
247    })
248    .detach();
249}
250
// Test builds only: `ctor` runs this when the test binary starts. It
// initializes `env_logger` only when `RUST_LOG` is set to a valid Unicode
// value, so tests stay quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}