1package cmd
2
3import (
4 "bytes"
5 "context"
6 "database/sql"
7 _ "embed"
8 "encoding/json"
9 "fmt"
10 "html/template"
11 "os"
12 "os/user"
13 "path/filepath"
14 "strings"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/db"
19 "github.com/pkg/browser"
20 "github.com/spf13/cobra"
21)
22
23//go:embed stats/index.html
24var statsTemplate string
25
26//go:embed stats/index.css
27var statsCSS string
28
29//go:embed stats/index.js
30var statsJS string
31
32//go:embed stats/header.svg
33var headerSVG string
34
35//go:embed stats/heartbit.svg
36var heartbitSVG string
37
38//go:embed stats/footer.svg
39var footerSVG string
40
41var statsCmd = &cobra.Command{
42 Use: "stats",
43 Short: "Show usage statistics",
44 Long: "Generate and display usage statistics including token usage, costs, and activity patterns",
45 RunE: runStats,
46}
47
// dayNames holds the English weekday names used to label day-of-week
// statistics; index 0 is Sunday and index 6 is Saturday.
var dayNames = []string{
	"Sunday",
	"Monday",
	"Tuesday",
	"Wednesday",
	"Thursday",
	"Friday",
	"Saturday",
}
50
type (
	// Stats is the complete statistics snapshot rendered by the stats
	// command. It is marshaled to JSON for the page, so the field order
	// below is also the key order of that JSON document.
	Stats struct {
		GeneratedAt       time.Time          `json:"generated_at"`
		Total             TotalStats         `json:"total"`
		UsageByDay        []DailyUsage       `json:"usage_by_day"`
		UsageByModel      []ModelUsage       `json:"usage_by_model"`
		UsageByHour       []HourlyUsage      `json:"usage_by_hour"`
		UsageByDayOfWeek  []DayOfWeekUsage   `json:"usage_by_day_of_week"`
		RecentActivity    []DailyActivity    `json:"recent_activity"`
		AvgResponseTimeMs float64            `json:"avg_response_time_ms"`
		ToolUsage         []ToolUsage        `json:"tool_usage"`
		HourDayHeatmap    []HourDayHeatmapPt `json:"hour_day_heatmap"`
	}

	// TotalStats carries the all-time totals and per-session averages.
	// TotalTokens is expected to equal TotalPromptTokens plus
	// TotalCompletionTokens. The unit of TotalCost is not established
	// here (presumably the provider's billing currency — TODO confirm).
	TotalStats struct {
		TotalSessions         int64   `json:"total_sessions"`
		TotalPromptTokens     int64   `json:"total_prompt_tokens"`
		TotalCompletionTokens int64   `json:"total_completion_tokens"`
		TotalTokens           int64   `json:"total_tokens"`
		TotalCost             float64 `json:"total_cost"`
		TotalMessages         int64   `json:"total_messages"`
		AvgTokensPerSession   float64 `json:"avg_tokens_per_session"`
		AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
	}

	// DailyUsage is the token, cost and session tally for a single day.
	// TotalTokens is expected to equal PromptTokens plus CompletionTokens.
	DailyUsage struct {
		Day              string  `json:"day"`
		PromptTokens     int64   `json:"prompt_tokens"`
		CompletionTokens int64   `json:"completion_tokens"`
		TotalTokens      int64   `json:"total_tokens"`
		Cost             float64 `json:"cost"`
		SessionCount     int64   `json:"session_count"`
	}

	// ModelUsage is the message count attributed to one model/provider pair.
	ModelUsage struct {
		Model        string `json:"model"`
		Provider     string `json:"provider"`
		MessageCount int64  `json:"message_count"`
	}

	// HourlyUsage is the session count for one hour of the day
	// (presumably 0-23 — confirm against the query).
	HourlyUsage struct {
		Hour         int   `json:"hour"`
		SessionCount int64 `json:"session_count"`
	}

	// DayOfWeekUsage is the session and token tally for one weekday.
	// DayOfWeek is expected to index dayNames, and DayName is the matching
	// label.
	DayOfWeekUsage struct {
		DayOfWeek        int    `json:"day_of_week"`
		DayName          string `json:"day_name"`
		SessionCount     int64  `json:"session_count"`
		PromptTokens     int64  `json:"prompt_tokens"`
		CompletionTokens int64  `json:"completion_tokens"`
	}

	// DailyActivity is one day of recent activity (presumably within a
	// trailing window such as the last 30 days — confirm against the query).
	DailyActivity struct {
		Day          string  `json:"day"`
		SessionCount int64   `json:"session_count"`
		TotalTokens  int64   `json:"total_tokens"`
		Cost         float64 `json:"cost"`
	}

	// ToolUsage is the number of recorded calls for a single tool name.
	ToolUsage struct {
		ToolName  string `json:"tool_name"`
		CallCount int64  `json:"call_count"`
	}

	// HourDayHeatmapPt is one cell of the weekday-by-hour heatmap: the
	// session count observed for DayOfWeek at Hour.
	HourDayHeatmapPt struct {
		DayOfWeek    int   `json:"day_of_week"`
		Hour         int   `json:"hour"`
		SessionCount int64 `json:"session_count"`
	}
)
121
122func runStats(cmd *cobra.Command, _ []string) error {
123 dataDir, _ := cmd.Flags().GetString("data-dir")
124 ctx := cmd.Context()
125
126 if dataDir == "" {
127 cfg, err := config.Init("", "", false)
128 if err != nil {
129 return fmt.Errorf("failed to initialize config: %w", err)
130 }
131 dataDir = cfg.Options.DataDirectory
132 }
133
134 conn, err := db.Connect(ctx, dataDir)
135 if err != nil {
136 return fmt.Errorf("failed to connect to database: %w", err)
137 }
138 defer conn.Close()
139
140 stats, err := gatherStats(ctx, conn)
141 if err != nil {
142 return fmt.Errorf("failed to gather stats: %w", err)
143 }
144
145 if stats.Total.TotalSessions == 0 {
146 return fmt.Errorf("no data available: no sessions found in database")
147 }
148
149 currentUser, err := user.Current()
150 if err != nil {
151 return fmt.Errorf("failed to get current user: %w", err)
152 }
153 username := currentUser.Username
154 project, err := os.Getwd()
155 if err != nil {
156 return fmt.Errorf("failed to get current directory: %w", err)
157 }
158 project = strings.Replace(project, currentUser.HomeDir, "~", 1)
159
160 htmlPath := filepath.Join(dataDir, "stats/index.html")
161 if err := generateHTML(stats, project, username, htmlPath); err != nil {
162 return fmt.Errorf("failed to generate HTML: %w", err)
163 }
164
165 fmt.Printf("Stats generated: %s\n", htmlPath)
166
167 if err := browser.OpenFile(htmlPath); err != nil {
168 fmt.Printf("Could not open browser: %v\n", err)
169 fmt.Println("Please open the file manually.")
170 }
171
172 return nil
173}
174
175func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
176 queries := db.New(conn)
177
178 stats := &Stats{
179 GeneratedAt: time.Now(),
180 }
181
182 // Total stats.
183 total, err := queries.GetTotalStats(ctx)
184 if err != nil {
185 return nil, fmt.Errorf("get total stats: %w", err)
186 }
187 stats.Total = TotalStats{
188 TotalSessions: total.TotalSessions,
189 TotalPromptTokens: toInt64(total.TotalPromptTokens),
190 TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
191 TotalTokens: toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
192 TotalCost: toFloat64(total.TotalCost),
193 TotalMessages: toInt64(total.TotalMessages),
194 AvgTokensPerSession: toFloat64(total.AvgTokensPerSession),
195 AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
196 }
197
198 // Usage by day.
199 dailyUsage, err := queries.GetUsageByDay(ctx)
200 if err != nil {
201 return nil, fmt.Errorf("get usage by day: %w", err)
202 }
203 for _, d := range dailyUsage {
204 prompt := nullFloat64ToInt64(d.PromptTokens)
205 completion := nullFloat64ToInt64(d.CompletionTokens)
206 stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
207 Day: fmt.Sprintf("%v", d.Day),
208 PromptTokens: prompt,
209 CompletionTokens: completion,
210 TotalTokens: prompt + completion,
211 Cost: d.Cost.Float64,
212 SessionCount: d.SessionCount,
213 })
214 }
215
216 // Usage by model.
217 modelUsage, err := queries.GetUsageByModel(ctx)
218 if err != nil {
219 return nil, fmt.Errorf("get usage by model: %w", err)
220 }
221 for _, m := range modelUsage {
222 stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
223 Model: m.Model,
224 Provider: m.Provider,
225 MessageCount: m.MessageCount,
226 })
227 }
228
229 // Usage by hour.
230 hourlyUsage, err := queries.GetUsageByHour(ctx)
231 if err != nil {
232 return nil, fmt.Errorf("get usage by hour: %w", err)
233 }
234 for _, h := range hourlyUsage {
235 stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
236 Hour: int(h.Hour),
237 SessionCount: h.SessionCount,
238 })
239 }
240
241 // Usage by day of week.
242 dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
243 if err != nil {
244 return nil, fmt.Errorf("get usage by day of week: %w", err)
245 }
246 for _, d := range dowUsage {
247 stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
248 DayOfWeek: int(d.DayOfWeek),
249 DayName: dayNames[int(d.DayOfWeek)],
250 SessionCount: d.SessionCount,
251 PromptTokens: nullFloat64ToInt64(d.PromptTokens),
252 CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
253 })
254 }
255
256 // Recent activity (last 30 days).
257 recent, err := queries.GetRecentActivity(ctx)
258 if err != nil {
259 return nil, fmt.Errorf("get recent activity: %w", err)
260 }
261 for _, r := range recent {
262 stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
263 Day: fmt.Sprintf("%v", r.Day),
264 SessionCount: r.SessionCount,
265 TotalTokens: nullFloat64ToInt64(r.TotalTokens),
266 Cost: r.Cost.Float64,
267 })
268 }
269
270 // Average response time.
271 avgResp, err := queries.GetAverageResponseTime(ctx)
272 if err != nil {
273 return nil, fmt.Errorf("get average response time: %w", err)
274 }
275 stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
276
277 // Tool usage.
278 toolUsage, err := queries.GetToolUsage(ctx)
279 if err != nil {
280 return nil, fmt.Errorf("get tool usage: %w", err)
281 }
282 for _, t := range toolUsage {
283 if name, ok := t.ToolName.(string); ok && name != "" {
284 stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
285 ToolName: name,
286 CallCount: t.CallCount,
287 })
288 }
289 }
290
291 // Hour/day heatmap.
292 heatmap, err := queries.GetHourDayHeatmap(ctx)
293 if err != nil {
294 return nil, fmt.Errorf("get hour day heatmap: %w", err)
295 }
296 for _, h := range heatmap {
297 stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
298 DayOfWeek: int(h.DayOfWeek),
299 Hour: int(h.Hour),
300 SessionCount: h.SessionCount,
301 })
302 }
303
304 return stats, nil
305}
306
// toInt64 converts a loosely typed query result to int64.
//
// int64 and int values are converted directly. float64 values are truncated
// toward zero. Any other dynamic type, including nil, yields 0.
func toInt64(v any) int64 {
	if n, ok := v.(int64); ok {
		return n
	}
	if f, ok := v.(float64); ok {
		return int64(f)
	}
	if n, ok := v.(int); ok {
		return int64(n)
	}
	return 0
}
319
// toFloat64 converts a loosely typed query result to float64.
//
// float64, int64 and int values are converted. Any other dynamic type,
// including nil, yields 0.
func toFloat64(v any) float64 {
	if f, ok := v.(float64); ok {
		return f
	}
	if n, ok := v.(int64); ok {
		return float64(n)
	}
	if n, ok := v.(int); ok {
		return float64(n)
	}
	return 0
}
332
// nullFloat64ToInt64 returns the value of n truncated toward zero, or 0 when
// n is NULL (Valid is false).
func nullFloat64ToInt64(n sql.NullFloat64) int64 {
	if !n.Valid {
		return 0
	}
	return int64(n.Float64)
}
339
340func generateHTML(stats *Stats, projName, username, path string) error {
341 statsJSON, err := json.Marshal(stats)
342 if err != nil {
343 return err
344 }
345
346 tmpl, err := template.New("stats").Parse(statsTemplate)
347 if err != nil {
348 return fmt.Errorf("parse template: %w", err)
349 }
350
351 data := struct {
352 StatsJSON template.JS
353 CSS template.CSS
354 JS template.JS
355 Header template.HTML
356 Heartbit template.HTML
357 Footer template.HTML
358 GeneratedAt string
359 ProjectName string
360 Username string
361 }{
362 StatsJSON: template.JS(statsJSON),
363 CSS: template.CSS(statsCSS),
364 JS: template.JS(statsJS),
365 Header: template.HTML(headerSVG),
366 Heartbit: template.HTML(heartbitSVG),
367 Footer: template.HTML(footerSVG),
368 GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
369 ProjectName: projName,
370 Username: username,
371 }
372
373 var buf bytes.Buffer
374 if err := tmpl.Execute(&buf, data); err != nil {
375 return fmt.Errorf("execute template: %w", err)
376 }
377
378 // Ensure parent directory exists.
379 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
380 return fmt.Errorf("create directory: %w", err)
381 }
382
383 return os.WriteFile(path, buf.Bytes(), 0o644)
384}