1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod prompts;
6mod saved_conversation;
7mod streaming_diff;
8
9pub use assistant_panel::AssistantPanel;
10use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
11use client::{proto, Client};
12use command_palette_hooks::CommandPaletteFilter;
13pub(crate) use completion_provider::*;
14use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
15pub(crate) use saved_conversation::*;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{
19 fmt::{self, Display},
20 sync::Arc,
21};
22
23actions!(
24 assistant,
25 [
26 Assist,
27 Split,
28 CycleMessageRole,
29 QuoteSelection,
30 ToggleFocus,
31 ResetKey,
32 InlineAssist,
33 ToggleIncludeConversation,
34 ToggleHistory,
35 ]
36);
37
38#[derive(
39 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
40)]
41struct MessageId(usize);
42
43#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
44#[serde(rename_all = "lowercase")]
45pub enum Role {
46 User,
47 Assistant,
48 System,
49}
50
51impl Role {
52 pub fn cycle(&mut self) {
53 *self = match self {
54 Role::User => Role::Assistant,
55 Role::Assistant => Role::System,
56 Role::System => Role::User,
57 }
58 }
59}
60
61impl Display for Role {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 Role::User => write!(f, "user"),
65 Role::Assistant => write!(f, "assistant"),
66 Role::System => write!(f, "system"),
67 }
68 }
69}
70
71#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
72pub enum LanguageModel {
73 ZedDotDev(ZedDotDevModel),
74 OpenAi(OpenAiModel),
75 Anthropic(AnthropicModel),
76}
77
78impl Default for LanguageModel {
79 fn default() -> Self {
80 LanguageModel::ZedDotDev(ZedDotDevModel::default())
81 }
82}
83
84impl LanguageModel {
85 pub fn telemetry_id(&self) -> String {
86 match self {
87 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
88 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
89 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
90 }
91 }
92
93 pub fn display_name(&self) -> String {
94 match self {
95 LanguageModel::OpenAi(model) => model.display_name().into(),
96 LanguageModel::Anthropic(model) => model.display_name().into(),
97 LanguageModel::ZedDotDev(model) => model.display_name().into(),
98 }
99 }
100
101 pub fn max_token_count(&self) -> usize {
102 match self {
103 LanguageModel::OpenAi(model) => model.max_token_count(),
104 LanguageModel::Anthropic(model) => model.max_token_count(),
105 LanguageModel::ZedDotDev(model) => model.max_token_count(),
106 }
107 }
108
109 pub fn id(&self) -> &str {
110 match self {
111 LanguageModel::OpenAi(model) => model.id(),
112 LanguageModel::Anthropic(model) => model.id(),
113 LanguageModel::ZedDotDev(model) => model.id(),
114 }
115 }
116}
117
118#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
119pub struct LanguageModelRequestMessage {
120 pub role: Role,
121 pub content: String,
122}
123
124impl LanguageModelRequestMessage {
125 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
126 proto::LanguageModelRequestMessage {
127 role: match self.role {
128 Role::User => proto::LanguageModelRole::LanguageModelUser,
129 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
130 Role::System => proto::LanguageModelRole::LanguageModelSystem,
131 } as i32,
132 content: self.content.clone(),
133 tool_calls: Vec::new(),
134 tool_call_id: None,
135 }
136 }
137}
138
139#[derive(Debug, Default, Serialize)]
140pub struct LanguageModelRequest {
141 pub model: LanguageModel,
142 pub messages: Vec<LanguageModelRequestMessage>,
143 pub stop: Vec<String>,
144 pub temperature: f32,
145}
146
147impl LanguageModelRequest {
148 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
149 proto::CompleteWithLanguageModel {
150 model: self.model.id().to_string(),
151 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
152 stop: self.stop.clone(),
153 temperature: self.temperature,
154 tool_choice: None,
155 tools: Vec::new(),
156 }
157 }
158}
159
160#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
161pub struct LanguageModelResponseMessage {
162 pub role: Option<Role>,
163 pub content: Option<String>,
164}
165
166#[derive(Deserialize, Debug)]
167pub struct LanguageModelUsage {
168 pub prompt_tokens: u32,
169 pub completion_tokens: u32,
170 pub total_tokens: u32,
171}
172
173#[derive(Deserialize, Debug)]
174pub struct LanguageModelChoiceDelta {
175 pub index: u32,
176 pub delta: LanguageModelResponseMessage,
177 pub finish_reason: Option<String>,
178}
179
180#[derive(Clone, Debug, Serialize, Deserialize)]
181struct MessageMetadata {
182 role: Role,
183 status: MessageStatus,
184}
185
186#[derive(Clone, Debug, Serialize, Deserialize)]
187enum MessageStatus {
188 Pending,
189 Done,
190 Error(SharedString),
191}
192
193/// The state pertaining to the Assistant.
194#[derive(Default)]
195struct Assistant {
196 /// Whether the Assistant is enabled.
197 enabled: bool,
198}
199
200impl Global for Assistant {}
201
202impl Assistant {
203 const NAMESPACE: &'static str = "assistant";
204
205 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
206 if self.enabled == enabled {
207 return;
208 }
209
210 self.enabled = enabled;
211
212 if !enabled {
213 CommandPaletteFilter::update_global(cx, |filter, _cx| {
214 filter.hide_namespace(Self::NAMESPACE);
215 });
216
217 return;
218 }
219
220 CommandPaletteFilter::update_global(cx, |filter, _cx| {
221 filter.show_namespace(Self::NAMESPACE);
222 });
223 }
224}
225
226pub fn init(client: Arc<Client>, cx: &mut AppContext) {
227 cx.set_global(Assistant::default());
228 AssistantSettings::register(cx);
229 completion_provider::init(client, cx);
230 assistant_panel::init(cx);
231
232 CommandPaletteFilter::update_global(cx, |filter, _cx| {
233 filter.hide_namespace(Assistant::NAMESPACE);
234 });
235 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
236 let settings = AssistantSettings::get_global(cx);
237
238 assistant.set_enabled(settings.enabled, cx);
239 });
240 cx.observe_global::<SettingsStore>(|cx| {
241 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
242 let settings = AssistantSettings::get_global(cx);
243
244 assistant.set_enabled(settings.enabled, cx);
245 });
246 })
247 .detach();
248}
249
/// Test-only constructor, run via `ctor` when the test binary loads.
/// It installs `env_logger` only when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}