1package cmd
2
3import (
4 "bytes"
5 "context"
6 "database/sql"
7 _ "embed"
8 "encoding/json"
9 "fmt"
10 "html/template"
11 "os"
12 "os/user"
13 "path/filepath"
14 "strings"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/db"
19 "github.com/charmbracelet/crush/internal/event"
20 "github.com/pkg/browser"
21 "github.com/spf13/cobra"
22)
23
// Report assets compiled into the binary with //go:embed. generateHTML feeds
// them into the template data so the generated report needs no external files.

// statsTemplate is the html/template source that generateHTML parses and
// executes to produce the report.
//go:embed stats/index.html
var statsTemplate string

// statsCSS is the report stylesheet; generateHTML passes it as template.CSS.
//go:embed stats/index.css
var statsCSS string

// statsJS is the report script; generateHTML passes it as template.JS.
//go:embed stats/index.js
var statsJS string

// headerSVG is inline SVG markup passed to the template as template.HTML.
//go:embed stats/header.svg
var headerSVG string

// heartbitSVG is inline SVG markup passed to the template as template.HTML.
//go:embed stats/heartbit.svg
var heartbitSVG string

// footerSVG is inline SVG markup passed to the template as template.HTML.
//go:embed stats/footer.svg
var footerSVG string
41
42var statsCmd = &cobra.Command{
43 Use: "stats",
44 Short: "Show usage statistics",
45 Long: "Generate and display usage statistics including token usage, costs, and activity patterns",
46 RunE: runStats,
47}
48
// dayNames holds the English weekday names used to label day-of-week
// statistics. Index i is the name of time.Weekday(i), so index 0 is "Sunday"
// and index 6 is "Saturday".
var dayNames = func() []string {
	names := make([]string, 7)
	for wd := time.Sunday; wd <= time.Saturday; wd++ {
		names[wd] = wd.String()
	}
	return names
}()
51
// Stats holds every statistic shown in the usage report. gatherStats fills it
// from the database and generateHTML serializes it to JSON (using the json
// tags below) and embeds that JSON in the generated HTML.
//
// GeneratedAt is the time gatherStats ran. AvgResponseTimeMs is the result of
// the average-response-time query multiplied by 1000 (presumably seconds
// converted to milliseconds — confirm against the SQL query).
type Stats struct {
	GeneratedAt       time.Time          `json:"generated_at"`
	Total             TotalStats         `json:"total"`
	UsageByDay        []DailyUsage       `json:"usage_by_day"`
	UsageByModel      []ModelUsage       `json:"usage_by_model"`
	UsageByHour       []HourlyUsage      `json:"usage_by_hour"`
	UsageByDayOfWeek  []DayOfWeekUsage   `json:"usage_by_day_of_week"`
	RecentActivity    []DailyActivity    `json:"recent_activity"`
	AvgResponseTimeMs float64            `json:"avg_response_time_ms"`
	ToolUsage         []ToolUsage        `json:"tool_usage"`
	HourDayHeatmap    []HourDayHeatmapPt `json:"hour_day_heatmap"`
}
65
// TotalStats holds the all-time aggregate totals from the total-stats query.
// TotalTokens is not read from the database: gatherStats computes it as
// TotalPromptTokens + TotalCompletionTokens. Values the query returns in a
// non-numeric form are reported as 0.
type TotalStats struct {
	TotalSessions         int64   `json:"total_sessions"`
	TotalPromptTokens     int64   `json:"total_prompt_tokens"`
	TotalCompletionTokens int64   `json:"total_completion_tokens"`
	TotalTokens           int64   `json:"total_tokens"`
	TotalCost             float64 `json:"total_cost"`
	TotalMessages         int64   `json:"total_messages"`
	AvgTokensPerSession   float64 `json:"avg_tokens_per_session"`
	AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
}
76
// DailyUsage is one row of per-day token, cost and session usage.
// Day is the query's day value formatted with %v. TotalTokens is
// PromptTokens + CompletionTokens, computed in gatherStats. NULL token counts
// and costs from the database are reported as 0.
type DailyUsage struct {
	Day              string  `json:"day"`
	PromptTokens     int64   `json:"prompt_tokens"`
	CompletionTokens int64   `json:"completion_tokens"`
	TotalTokens      int64   `json:"total_tokens"`
	Cost             float64 `json:"cost"`
	SessionCount     int64   `json:"session_count"`
}
85
// ModelUsage is the number of messages recorded for one model/provider pair,
// copied from the usage-by-model query.
type ModelUsage struct {
	Model        string `json:"model"`
	Provider     string `json:"provider"`
	MessageCount int64  `json:"message_count"`
}
91
// HourlyUsage is the number of sessions that fall into one hour of the day.
// Hour is presumably 0-23; whether it is local time or UTC depends on the SQL
// query (confirm there).
type HourlyUsage struct {
	Hour         int   `json:"hour"`
	SessionCount int64 `json:"session_count"`
}
96
// DayOfWeekUsage aggregates sessions and tokens for one day of the week.
// DayOfWeek is an index into dayNames (0 = Sunday), and DayName is the
// matching English name. NULL token counts are reported as 0.
type DayOfWeekUsage struct {
	DayOfWeek        int    `json:"day_of_week"`
	DayName          string `json:"day_name"`
	SessionCount     int64  `json:"session_count"`
	PromptTokens     int64  `json:"prompt_tokens"`
	CompletionTokens int64  `json:"completion_tokens"`
}
104
// DailyActivity is one row of the recent-activity query (described in
// gatherStats as the last 30 days). Day is the query's day value formatted
// with %v. NULL token totals and costs are reported as 0.
type DailyActivity struct {
	Day          string  `json:"day"`
	SessionCount int64   `json:"session_count"`
	TotalTokens  int64   `json:"total_tokens"`
	Cost         float64 `json:"cost"`
}
111
// ToolUsage is the number of calls made to one tool. gatherStats only emits
// rows whose tool name is a non-empty string.
type ToolUsage struct {
	ToolName  string `json:"tool_name"`
	CallCount int64  `json:"call_count"`
}
116
// HourDayHeatmapPt is one cell of the day-of-week × hour heatmap: the number
// of sessions for a given DayOfWeek and Hour. DayOfWeek presumably uses the
// same numbering as DayOfWeekUsage (0 = Sunday) — confirm against the query.
type HourDayHeatmapPt struct {
	DayOfWeek    int   `json:"day_of_week"`
	Hour         int   `json:"hour"`
	SessionCount int64 `json:"session_count"`
}
122
123func runStats(cmd *cobra.Command, _ []string) error {
124 event.StatsViewed()
125
126 dataDir, _ := cmd.Flags().GetString("data-dir")
127 ctx := cmd.Context()
128
129 if dataDir == "" {
130 cfg, err := config.Init("", "", false)
131 if err != nil {
132 return fmt.Errorf("failed to initialize config: %w", err)
133 }
134 dataDir = cfg.Config().Options.DataDirectory
135 }
136
137 conn, err := db.Connect(ctx, dataDir)
138 if err != nil {
139 return fmt.Errorf("failed to connect to database: %w", err)
140 }
141 defer conn.Close()
142
143 stats, err := gatherStats(ctx, conn)
144 if err != nil {
145 return fmt.Errorf("failed to gather stats: %w", err)
146 }
147
148 if stats.Total.TotalSessions == 0 {
149 return fmt.Errorf("no data available: no sessions found in database")
150 }
151
152 currentUser, err := user.Current()
153 if err != nil {
154 return fmt.Errorf("failed to get current user: %w", err)
155 }
156 username := currentUser.Username
157 project, err := os.Getwd()
158 if err != nil {
159 return fmt.Errorf("failed to get current directory: %w", err)
160 }
161 project = strings.Replace(project, currentUser.HomeDir, "~", 1)
162
163 htmlPath := filepath.Join(dataDir, "stats/index.html")
164 if err := generateHTML(stats, project, username, htmlPath); err != nil {
165 return fmt.Errorf("failed to generate HTML: %w", err)
166 }
167
168 fmt.Printf("Stats generated: %s\n", htmlPath)
169
170 if err := browser.OpenFile(htmlPath); err != nil {
171 fmt.Printf("Could not open browser: %v\n", err)
172 fmt.Println("Please open the file manually.")
173 }
174
175 return nil
176}
177
178func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
179 queries := db.New(conn)
180
181 stats := &Stats{
182 GeneratedAt: time.Now(),
183 }
184
185 // Total stats.
186 total, err := queries.GetTotalStats(ctx)
187 if err != nil {
188 return nil, fmt.Errorf("get total stats: %w", err)
189 }
190 stats.Total = TotalStats{
191 TotalSessions: total.TotalSessions,
192 TotalPromptTokens: toInt64(total.TotalPromptTokens),
193 TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
194 TotalTokens: toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
195 TotalCost: toFloat64(total.TotalCost),
196 TotalMessages: toInt64(total.TotalMessages),
197 AvgTokensPerSession: toFloat64(total.AvgTokensPerSession),
198 AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
199 }
200
201 // Usage by day.
202 dailyUsage, err := queries.GetUsageByDay(ctx)
203 if err != nil {
204 return nil, fmt.Errorf("get usage by day: %w", err)
205 }
206 for _, d := range dailyUsage {
207 prompt := nullFloat64ToInt64(d.PromptTokens)
208 completion := nullFloat64ToInt64(d.CompletionTokens)
209 stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
210 Day: fmt.Sprintf("%v", d.Day),
211 PromptTokens: prompt,
212 CompletionTokens: completion,
213 TotalTokens: prompt + completion,
214 Cost: d.Cost.Float64,
215 SessionCount: d.SessionCount,
216 })
217 }
218
219 // Usage by model.
220 modelUsage, err := queries.GetUsageByModel(ctx)
221 if err != nil {
222 return nil, fmt.Errorf("get usage by model: %w", err)
223 }
224 for _, m := range modelUsage {
225 stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
226 Model: m.Model,
227 Provider: m.Provider,
228 MessageCount: m.MessageCount,
229 })
230 }
231
232 // Usage by hour.
233 hourlyUsage, err := queries.GetUsageByHour(ctx)
234 if err != nil {
235 return nil, fmt.Errorf("get usage by hour: %w", err)
236 }
237 for _, h := range hourlyUsage {
238 stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
239 Hour: int(h.Hour),
240 SessionCount: h.SessionCount,
241 })
242 }
243
244 // Usage by day of week.
245 dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
246 if err != nil {
247 return nil, fmt.Errorf("get usage by day of week: %w", err)
248 }
249 for _, d := range dowUsage {
250 stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
251 DayOfWeek: int(d.DayOfWeek),
252 DayName: dayNames[int(d.DayOfWeek)],
253 SessionCount: d.SessionCount,
254 PromptTokens: nullFloat64ToInt64(d.PromptTokens),
255 CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
256 })
257 }
258
259 // Recent activity (last 30 days).
260 recent, err := queries.GetRecentActivity(ctx)
261 if err != nil {
262 return nil, fmt.Errorf("get recent activity: %w", err)
263 }
264 for _, r := range recent {
265 stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
266 Day: fmt.Sprintf("%v", r.Day),
267 SessionCount: r.SessionCount,
268 TotalTokens: nullFloat64ToInt64(r.TotalTokens),
269 Cost: r.Cost.Float64,
270 })
271 }
272
273 // Average response time.
274 avgResp, err := queries.GetAverageResponseTime(ctx)
275 if err != nil {
276 return nil, fmt.Errorf("get average response time: %w", err)
277 }
278 stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
279
280 // Tool usage.
281 toolUsage, err := queries.GetToolUsage(ctx)
282 if err != nil {
283 return nil, fmt.Errorf("get tool usage: %w", err)
284 }
285 for _, t := range toolUsage {
286 if name, ok := t.ToolName.(string); ok && name != "" {
287 stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
288 ToolName: name,
289 CallCount: t.CallCount,
290 })
291 }
292 }
293
294 // Hour/day heatmap.
295 heatmap, err := queries.GetHourDayHeatmap(ctx)
296 if err != nil {
297 return nil, fmt.Errorf("get hour day heatmap: %w", err)
298 }
299 for _, h := range heatmap {
300 stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
301 DayOfWeek: int(h.DayOfWeek),
302 Hour: int(h.Hour),
303 SessionCount: h.SessionCount,
304 })
305 }
306
307 return stats, nil
308}
309
// toInt64 converts a dynamically typed numeric value to int64. It accepts
// int64, float64 (truncated toward zero) and int. Any other type, including
// nil, yields 0.
func toInt64(v any) int64 {
	if i, ok := v.(int64); ok {
		return i
	}
	if f, ok := v.(float64); ok {
		return int64(f)
	}
	if n, ok := v.(int); ok {
		return int64(n)
	}
	return 0
}
322
// toFloat64 converts a dynamically typed numeric value to float64. It accepts
// float64, int64 and int. Any other type, including nil, yields 0.
func toFloat64(v any) float64 {
	if f, ok := v.(float64); ok {
		return f
	}
	if i, ok := v.(int64); ok {
		return float64(i)
	}
	if n, ok := v.(int); ok {
		return float64(n)
	}
	return 0
}
335
336func nullFloat64ToInt64(n sql.NullFloat64) int64 {
337 if n.Valid {
338 return int64(n.Float64)
339 }
340 return 0
341}
342
343func generateHTML(stats *Stats, projName, username, path string) error {
344 statsJSON, err := json.Marshal(stats)
345 if err != nil {
346 return err
347 }
348
349 tmpl, err := template.New("stats").Parse(statsTemplate)
350 if err != nil {
351 return fmt.Errorf("parse template: %w", err)
352 }
353
354 data := struct {
355 StatsJSON template.JS
356 CSS template.CSS
357 JS template.JS
358 Header template.HTML
359 Heartbit template.HTML
360 Footer template.HTML
361 GeneratedAt string
362 ProjectName string
363 Username string
364 }{
365 StatsJSON: template.JS(statsJSON),
366 CSS: template.CSS(statsCSS),
367 JS: template.JS(statsJS),
368 Header: template.HTML(headerSVG),
369 Heartbit: template.HTML(heartbitSVG),
370 Footer: template.HTML(footerSVG),
371 GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
372 ProjectName: projName,
373 Username: username,
374 }
375
376 var buf bytes.Buffer
377 if err := tmpl.Execute(&buf, data); err != nil {
378 return fmt.Errorf("execute template: %w", err)
379 }
380
381 // Ensure parent directory exists.
382 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
383 return fmt.Errorf("create directory: %w", err)
384 }
385
386 return os.WriteFile(path, buf.Bytes(), 0o644)
387}