Detailed changes
@@ -959,6 +959,46 @@
"created_at": "2025-12-14T09:41:12Z",
"repoId": 987670088,
"pullRequestNo": 1628
+ },
+ {
+ "name": "flatsponge",
+ "id": 104839509,
+ "comment_id": 3673002560,
+ "created_at": "2025-12-19T01:11:45Z",
+ "repoId": 987670088,
+ "pullRequestNo": 1668
+ },
+ {
+ "name": "jonhoo",
+ "id": 176295,
+ "comment_id": 3674853134,
+ "created_at": "2025-12-19T12:14:08Z",
+ "repoId": 987670088,
+ "pullRequestNo": 1675
+ },
+ {
+ "name": "Mr777x-enf",
+ "id": 248610315,
+ "comment_id": 3682737876,
+ "created_at": "2025-12-22T16:07:47Z",
+ "repoId": 987670088,
+ "pullRequestNo": 1694
+ },
+ {
+ "name": "yuguorui",
+ "id": 6182414,
+ "comment_id": 3687495909,
+ "created_at": "2025-12-23T17:59:11Z",
+ "repoId": 987670088,
+ "pullRequestNo": 1709
+ },
+ {
+ "name": "aeroxy",
+ "id": 2761307,
+ "comment_id": 3693734613,
+ "created_at": "2025-12-27T06:01:58Z",
+ "repoId": 987670088,
+ "pullRequestNo": 1723
}
]
}
@@ -19,7 +19,7 @@ jobs:
go-version-file: go.mod
- run: go run . schema > ./schema.json
- run: go generate ./internal/agent/hyper/...
- - uses: stefanzweifel/git-auto-commit-action@28e16e81777b558cc906c8750092100bbb34c5e3 # v5
+ - uses: stefanzweifel/git-auto-commit-action@04702edda442b2e678b25b537cec683a1493fcb9 # v5
with:
commit_message: "chore: auto-update files"
branch: main
@@ -362,6 +362,49 @@ completely hidden from the agent.
To disable tools from MCP servers, see the [MCP config section](#mcps).
+### Agent Skills
+
+Crush supports the [Agent Skills](https://agentskills.io) open standard for
+extending agent capabilities with reusable skill packages. Skills are folders
+containing a `SKILL.md` file with instructions that Crush can discover and
+activate on demand.
+
+Skills are discovered from:
+
+- `~/.config/crush/skills/` on Unix (default, can be overridden with `CRUSH_SKILLS_DIR`)
+- `%LOCALAPPDATA%\crush\skills\` on Windows (default, can be overridden with `CRUSH_SKILLS_DIR`)
+- Additional paths configured via `options.skills_paths`
+
+```jsonc
+{
+ "$schema": "https://charm.land/crush.json",
+ "options": {
+ "skills_paths": [
+ "~/.config/crush/skills", // Windows: "%LOCALAPPDATA%\\crush\\skills",
+ "./project-skills"
+ ]
+ }
+}
+```
+
+You can get started with example skills from [anthropics/skills](https://github.com/anthropics/skills):
+
+```bash
+# Unix
+mkdir -p ~/.config/crush/skills
+cd ~/.config/crush/skills
+git clone https://github.com/anthropics/skills.git _temp
+mv _temp/skills/* . && rm -rf _temp
+```
+
+```powershell
+# Windows (PowerShell)
+mkdir -Force "$env:LOCALAPPDATA\crush\skills"
+cd "$env:LOCALAPPDATA\crush\skills"
+git clone https://github.com/anthropics/skills.git _temp
+mv _temp/skills/* . ; rm -r -force _temp
+```
+
### Initialization
When you initialize a project, Crush analyzes your codebase and creates
@@ -5,7 +5,7 @@ go 1.25.5
require (
charm.land/bubbles/v2 v2.0.0-rc.1
charm.land/bubbletea/v2 v2.0.0-rc.2.0.20251216153312-819e2e89c62e
- charm.land/fantasy v0.5.3
+ charm.land/fantasy v0.5.5
charm.land/glamour/v2 v2.0.0-20251110203732-69649f93d3b1
charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251205162909-7869489d8971
charm.land/log/v2 v2.0.0-20251110204020-529bb77f35da
@@ -13,12 +13,12 @@ require (
github.com/JohannesKaufmann/html-to-markdown v1.6.0
github.com/MakeNowJust/heredoc v1.0.0
github.com/PuerkitoBio/goquery v1.11.0
- github.com/alecthomas/chroma/v2 v2.20.0
+ github.com/alecthomas/chroma/v2 v2.21.1
github.com/atotto/clipboard v0.1.4
github.com/aymanbagabas/go-udiff v0.3.1
github.com/bmatcuk/doublestar/v4 v4.9.1
github.com/charlievieth/fastwalk v1.0.14
- github.com/charmbracelet/catwalk v0.11.0
+ github.com/charmbracelet/catwalk v0.11.2
github.com/charmbracelet/colorprofile v0.4.1
github.com/charmbracelet/fang v0.4.4
github.com/charmbracelet/ultraviolet v0.0.0-20251212194010-b927aa605560
@@ -39,7 +39,7 @@ require (
github.com/lucasb-eyer/go-colorful v1.3.0
github.com/modelcontextprotocol/go-sdk v1.1.0
github.com/muesli/termenv v0.16.0
- github.com/ncruces/go-sqlite3 v0.30.3
+ github.com/ncruces/go-sqlite3 v0.30.4
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/nxadm/tail v1.4.11
github.com/openai/openai-go/v2 v2.7.1
@@ -62,6 +62,7 @@ require (
golang.org/x/sync v0.19.0
golang.org/x/text v0.32.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
+ gopkg.in/yaml.v3 v3.0.1
mvdan.cc/sh/moreinterp v0.0.0-20250902163504-3cf4fd5717a5
mvdan.cc/sh/v3 v3.12.1-0.20250902163504-3cf4fd5717a5
)
@@ -145,7 +146,7 @@ require (
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/sourcegraph/jsonrpc2 v0.2.1 // indirect
github.com/spf13/pflag v1.0.9 // indirect
- github.com/tetratelabs/wazero v1.10.1 // indirect
+ github.com/tetratelabs/wazero v1.11.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/u-root/u-root v0.14.1-0.20250807200646-5e7721023dc7 // indirect
@@ -177,5 +178,4 @@ require (
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20251110073552-01de4eb40290 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
- gopkg.in/yaml.v3 v3.0.1 // indirect
)
@@ -2,8 +2,8 @@ charm.land/bubbles/v2 v2.0.0-rc.1 h1:EiIFVAc3Zi/yY86td+79mPhHR7AqZ1OxF+6ztpOCRaM
charm.land/bubbles/v2 v2.0.0-rc.1/go.mod h1:5AbN6cEd/47gkEf8TgiQ2O3RZ5QxMS14l9W+7F9fPC4=
charm.land/bubbletea/v2 v2.0.0-rc.2.0.20251216153312-819e2e89c62e h1:tXwTmgGpwZT7ParKF5xbEQBVjM2e1uKhKi/GpfU3mYQ=
charm.land/bubbletea/v2 v2.0.0-rc.2.0.20251216153312-819e2e89c62e/go.mod h1:pDM18flq3Z4njKZPA3zCvyVSSIJbMcoqlE82BdGUtL8=
-charm.land/fantasy v0.5.3 h1:+6meCTaH9lrqrcVTEBgsaSkkY0ctC/6dtIufKZcMdMI=
-charm.land/fantasy v0.5.3/go.mod h1:WnH5fJJRMGylx1fL1ow9Kfq0+sPMr5fenpHYAnoTlTg=
+charm.land/fantasy v0.5.5 h1:Dw/NBLH9HLX/ouCz604RXGD7BYzr0lT56/B4ylMGZjg=
+charm.land/fantasy v0.5.5/go.mod h1:QyJLJGissYdBifvitgAxFcYhNACSr0G1faC75CIESUk=
charm.land/glamour/v2 v2.0.0-20251110203732-69649f93d3b1 h1:9q4+yyU7105T3OrOx0csMyKnw89yMSijJ+rVld/Z2ek=
charm.land/glamour/v2 v2.0.0-20251110203732-69649f93d3b1/go.mod h1:J3kVhY6oHXZq5f+8vC3hmDO95fEvbqj3z7xDwxrfzU8=
charm.land/lipgloss/v2 v2.0.0-beta.3.0.20251205162909-7869489d8971 h1:xZFcNsJMiIDbFtWRyDmkKNk1sjojfaom4Zoe0cyH/8c=
@@ -39,10 +39,10 @@ github.com/RealAlexandreAI/json-repair v0.0.14 h1:4kTqotVonDVTio5n2yweRUELVcNe2x
github.com/RealAlexandreAI/json-repair v0.0.14/go.mod h1:GKJi5borR78O8c7HCVbgqjhoiVibZ6hJldxbc6dGrAI=
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
-github.com/alecthomas/chroma/v2 v2.20.0 h1:sfIHpxPyR07/Oylvmcai3X/exDlE8+FA820NTz+9sGw=
-github.com/alecthomas/chroma/v2 v2.20.0/go.mod h1:e7tViK0xh/Nf4BYHl00ycY6rV7b8iXBksI9E359yNmA=
-github.com/alecthomas/repr v0.5.1 h1:E3G4t2QbHTSNpPKBgMTln5KLkZHLOcU7r37J4pXBuIg=
-github.com/alecthomas/repr v0.5.1/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
+github.com/alecthomas/chroma/v2 v2.21.1 h1:FaSDrp6N+3pphkNKU6HPCiYLgm8dbe5UXIXcoBhZSWA=
+github.com/alecthomas/chroma/v2 v2.21.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
+github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
+github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU=
github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM=
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
@@ -92,8 +92,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg
github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY=
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw=
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4=
-github.com/charmbracelet/catwalk v0.11.0 h1:PU3rkc4h4YVJEn9Iyb/1rQAaF4hEd04fuG4tj3vv4dg=
-github.com/charmbracelet/catwalk v0.11.0/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ=
+github.com/charmbracelet/catwalk v0.11.2 h1:m+eE7yv/uIrKW95FpFeGDMFrAugotylX89XzpkZwlLk=
+github.com/charmbracelet/catwalk v0.11.2/go.mod h1:qg+Yl9oaZTkTvRscqbxfttzOFQ4v0pOT5XwC7b5O0NQ=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY=
@@ -248,8 +248,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
-github.com/ncruces/go-sqlite3 v0.30.3 h1:X/CgWW9GzmIAkEPrifhKqf0cC15DuOVxAJaHFTTAURQ=
-github.com/ncruces/go-sqlite3 v0.30.3/go.mod h1:AxKu9sRxkludimFocbktlY6LiYSkxiI5gTA8r+os/Nw=
+github.com/ncruces/go-sqlite3 v0.30.4 h1:j9hEoOL7f9ZoXl8uqXVniaq1VNwlWAXihZbTvhqPPjA=
+github.com/ncruces/go-sqlite3 v0.30.4/go.mod h1:7WR20VSC5IZusKhUdiR9y1NsUqnZgqIYCmKKoMEYg68=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
@@ -310,8 +310,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
-github.com/tetratelabs/wazero v1.10.1 h1:2DugeJf6VVk58KTPszlNfeeN8AhhpwcZqkJj2wwFuH8=
-github.com/tetratelabs/wazero v1.10.1/go.mod h1:DRm5twOQ5Gr1AoEdSi0CLjDQF1J9ZAuyqFIjl1KKfQU=
+github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
+github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -171,10 +171,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
var wg sync.WaitGroup
// Generate title if first message.
if len(msgs) == 0 {
+ titleCtx := ctx // Copy to avoid race with ctx reassignment below.
wg.Go(func() {
- sessionLock.Lock()
- a.generateTitle(ctx, ¤tSession, call.Prompt)
- sessionLock.Unlock()
+ a.generateTitle(titleCtx, call.SessionID, call.Prompt)
})
}
@@ -201,7 +200,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
var currentAssistant *message.Message
var shouldSummarize bool
result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
- Prompt: call.Prompt,
+ Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
Files: files,
Messages: history,
ProviderOptions: call.ProviderOptions,
@@ -650,11 +649,11 @@ func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
}
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
+ parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
var attachmentParts []message.ContentPart
for _, attachment := range call.Attachments {
attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
}
- parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
parts = append(parts, attachmentParts...)
msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
Role: message.User,
@@ -691,6 +690,9 @@ If not, please feel free to ignore. Again do not mention this message to the use
var files []fantasy.FilePart
for _, attachment := range attachments {
+ if attachment.IsText() {
+ continue
+ }
files = append(files, fantasy.FilePart{
Filename: attachment.FileName,
Data: attachment.Content,
@@ -723,7 +725,7 @@ func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.S
return msgs, nil
}
-func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
+func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
if prompt == "" {
return
}
@@ -768,8 +770,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
return
}
- session.Title = title
-
+ // Calculate usage and cost.
var openrouterCost *float64
for _, step := range resp.Steps {
stepCost := a.openrouterCost(step.ProviderMetadata)
@@ -782,8 +783,27 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
}
}
- a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
- _, saveErr := a.sessions.Save(ctx, *session)
+ modelConfig := a.smallModel.CatwalkCfg
+ cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
+ modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
+ modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
+ modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
+
+ if a.isClaudeCode() {
+ cost = 0
+ }
+
+ // Use override cost if available (e.g., from OpenRouter).
+ if openrouterCost != nil {
+ cost = *openrouterCost
+ }
+
+ promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
+ completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
+
+ // Atomically update only title and usage fields to avoid overriding other
+ // concurrent session updates.
+ saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
if saveErr != nil {
slog.Error("failed to save session title & usage", "error", saveErr)
return
@@ -178,6 +178,10 @@ func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel
GeneratedWith: true,
}
+ // Clear skills paths to ensure test reproducibility - user's skills
+ // would be included in prompt and break VCR cassette matching.
+ cfg.Options.SkillsPaths = []string{}
+
systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
if err != nil {
return nil, err
@@ -123,7 +123,14 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
}
if !model.CatwalkCfg.SupportsImages && attachments != nil {
- attachments = nil
+ // filter out image attachments
+ filteredAttachments := make([]message.Attachment, 0, len(attachments))
+ for _, att := range attachments {
+ if att.IsText() {
+ filteredAttachments = append(filteredAttachments, att)
+ }
+ }
+ attachments = filteredAttachments
}
providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
@@ -383,7 +390,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
tools.NewSourcegraphTool(nil),
tools.NewTodosTool(c.sessions),
- tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
+ tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
)
@@ -516,14 +523,13 @@ func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error
}, nil
}
-func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
+func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, isOauth bool) (fantasy.Provider, error) {
var opts []anthropic.Option
- if strings.HasPrefix(apiKey, "Bearer ") {
+ if isOauth {
// NOTE: Prevent the SDK from picking up the API key from env.
os.Setenv("ANTHROPIC_API_KEY", "")
-
- headers["Authorization"] = apiKey
+ headers["Authorization"] = fmt.Sprintf("Bearer %s", apiKey)
} else if apiKey != "" {
// X-Api-Key header
opts = append(opts, anthropic.WithAPIKey(apiKey))
@@ -541,7 +547,6 @@ func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map
httpClient := log.NewHTTPClient()
opts = append(opts, anthropic.WithHTTPClient(httpClient))
}
-
return anthropic.New(opts...)
}
@@ -722,7 +727,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
case openai.Name:
return c.buildOpenaiProvider(baseURL, apiKey, headers)
case anthropic.Name:
- return c.buildAnthropicProvider(baseURL, apiKey, headers)
+ return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.OAuthToken != nil)
case openrouter.Name:
return c.buildOpenrouterProvider(baseURL, apiKey, headers)
case azure.Name:
@@ -14,6 +14,7 @@ import (
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/home"
"github.com/charmbracelet/crush/internal/shell"
+ "github.com/charmbracelet/crush/internal/skills"
)
// Prompt represents a template-based prompt generator.
@@ -26,15 +27,16 @@ type Prompt struct {
}
type PromptDat struct {
- Provider string
- Model string
- Config config.Config
- WorkingDir string
- IsGitRepo bool
- Platform string
- Date string
- GitStatus string
- ContextFiles []ContextFile
+ Provider string
+ Model string
+ Config config.Config
+ WorkingDir string
+ IsGitRepo bool
+ Platform string
+ Date string
+ GitStatus string
+ ContextFiles []ContextFile
+ AvailSkillXML string
}
type ContextFile struct {
@@ -162,15 +164,28 @@ func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg con
files[pathKey] = content
}
+ // Discover and load skills metadata.
+ var availSkillXML string
+ if len(cfg.Options.SkillsPaths) > 0 {
+ expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
+ for _, pth := range cfg.Options.SkillsPaths {
+ expandedPaths = append(expandedPaths, expandPath(pth, cfg))
+ }
+ if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
+ availSkillXML = skills.ToPromptXML(discoveredSkills)
+ }
+ }
+
isGit := isGitRepo(cfg.WorkingDir())
data := PromptDat{
- Provider: provider,
- Model: model,
- Config: cfg,
- WorkingDir: filepath.ToSlash(workingDir),
- IsGitRepo: isGit,
- Platform: platform,
- Date: p.now().Format("1/2/2006"),
+ Provider: provider,
+ Model: model,
+ Config: cfg,
+ WorkingDir: filepath.ToSlash(workingDir),
+ IsGitRepo: isGit,
+ Platform: platform,
+ Date: p.now().Format("1/2/2006"),
+ AvailSkillXML: availSkillXML,
}
if isGit {
var err error
@@ -360,6 +360,16 @@ Diagnostics (lint/typecheck) included in tool output.
- Ignore issues in files you didn't touch (unless user asks)
</lsp>
{{end}}
+{{- if .AvailSkillXML}}
+
+{{.AvailSkillXML}}
+
+<skills_usage>
+When a user task matches a skill's description, read the skill's SKILL.md file to get full instructions.
+Skills are activated by reading their location path. Follow the skill's instructions to complete the task.
+If a skill mentions scripts, references, or assets, they are placed in the same folder as the skill itself (e.g., scripts/, references/, assets/ subdirectories within the skill's folder).
+</skills_usage>
+{{end}}
{{if .ContextFiles}}
<memory>
@@ -38,6 +38,7 @@ type viewTool struct {
lspClients *csync.Map[string, *lsp.Client]
workingDir string
permissions permission.Service
+ skillsPaths []string
}
type ViewResponseMetadata struct {
@@ -52,7 +53,7 @@ const (
MaxLineLength = 2000
)
-func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, workingDir string) fantasy.AgentTool {
+func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, workingDir string, skillsPaths ...string) fantasy.AgentTool {
return fantasy.NewAgentTool(
ViewToolName,
string(viewDescription),
@@ -76,8 +77,11 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss
}
relPath, err := filepath.Rel(absWorkingDir, absFilePath)
- if err != nil || strings.HasPrefix(relPath, "..") {
- // File is outside working directory, request permission
+ isOutsideWorkDir := err != nil || strings.HasPrefix(relPath, "..")
+ isSkillFile := isInSkillsPath(absFilePath, skillsPaths)
+
+ // Request permission for files outside working directory, unless it's a skill file.
+ if isOutsideWorkDir && !isSkillFile {
sessionID := GetSessionFromContext(ctx)
if sessionID == "" {
return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory")
@@ -137,19 +141,23 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss
return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
}
- // Check file size
- if fileInfo.Size() > MaxReadSize {
+ // Based on the specifications we should not limit the skills read.
+ if !isSkillFile && fileInfo.Size() > MaxReadSize {
return fantasy.NewTextErrorResponse(fmt.Sprintf("File is too large (%d bytes). Maximum size is %d bytes",
fileInfo.Size(), MaxReadSize)), nil
}
- // Set default limit if not provided
+ // Set default limit if not provided (no limit for SKILL.md files)
if params.Limit <= 0 {
- params.Limit = DefaultReadLimit
+ if isSkillFile {
+ params.Limit = 1000000 // Effectively no limit for skill files
+ } else {
+ params.Limit = DefaultReadLimit
+ }
}
- isImage, mimeType := getImageMimeType(filePath)
- if isImage {
+ isSupportedImage, mimeType := getImageMimeType(filePath)
+ if isSupportedImage {
if !GetSupportsImagesFromContext(ctx) {
modelName := GetModelNameFromContext(ctx)
return fantasy.NewTextErrorResponse(fmt.Sprintf("This model (%s) does not support image data.", modelName)), nil
@@ -282,10 +290,6 @@ func getImageMimeType(filePath string) (bool, string) {
return true, "image/png"
case ".gif":
return true, "image/gif"
- case ".bmp":
- return true, "image/bmp"
- case ".svg":
- return true, "image/svg+xml"
case ".webp":
return true, "image/webp"
default:
@@ -319,3 +323,44 @@ func (s *LineScanner) Text() string {
func (s *LineScanner) Err() error {
return s.scanner.Err()
}
+
+// isInSkillsPath checks if filePath is within any of the configured skills
+// directories. Returns true for files that can be read without permission
+// prompts and without size limits.
+//
+// Note that symlinks are resolved to prevent path traversal attacks via
+// symbolic links.
+func isInSkillsPath(filePath string, skillsPaths []string) bool {
+ if len(skillsPaths) == 0 {
+ return false
+ }
+
+ absFilePath, err := filepath.Abs(filePath)
+ if err != nil {
+ return false
+ }
+
+ evalFilePath, err := filepath.EvalSymlinks(absFilePath)
+ if err != nil {
+ return false
+ }
+
+ for _, skillsPath := range skillsPaths {
+ absSkillsPath, err := filepath.Abs(skillsPath)
+ if err != nil {
+ continue
+ }
+
+ evalSkillsPath, err := filepath.EvalSymlinks(absSkillsPath)
+ if err != nil {
+ continue
+ }
+
+ relPath, err := filepath.Rel(evalSkillsPath, evalFilePath)
+ if err == nil && !strings.HasPrefix(relPath, "..") {
+ return true
+ }
+ }
+
+ return false
+}
@@ -69,7 +69,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
messages := message.NewService(q)
files := history.NewService(q, conn)
skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests
- allowedTools := []string{}
+ var allowedTools []string
if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil {
allowedTools = cfg.Permissions.AllowedTools
}
@@ -415,7 +415,7 @@ func (app *App) Shutdown() {
})
}
- // Call call cleanup functions.
+ // Call all cleanup functions.
for _, cleanup := range app.cleanupFuncs {
if cleanup != nil {
wg.Go(func() {
@@ -38,7 +38,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, config
// Create LSP client.
lspClient, err := lsp.New(ctx, name, config, app.config.Resolver())
if err != nil {
- slog.Error("Failed to create LSP client for", name, err)
+ slog.Error("Failed to create LSP client for", "name", name, "error", err)
updateLSPState(name, lsp.StateError, err, nil, 0)
return
}
@@ -75,8 +75,7 @@ func loginHyper() error {
if !hyperp.Enabled() {
return fmt.Errorf("hyper not enabled")
}
- ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
- defer cancel()
+ ctx := getLoginContext()
resp, err := hyper.InitiateDeviceAuth(ctx)
if err != nil {
@@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"os"
+ "path/filepath"
"slices"
"strings"
"time"
@@ -255,6 +256,7 @@ func (Attribution) JSONSchemaExtend(schema *jsonschema.Schema) {
type Options struct {
ContextPaths []string `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
+ SkillsPaths []string `json:"skills_paths,omitempty" jsonschema:"description=Paths to directories containing Agent Skills (folders with SKILL.md files),example=~/.config/crush/skills,example=./skills"`
TUI *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
Debug bool `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
@@ -498,7 +500,6 @@ func (c *Config) HasConfigField(key string) bool {
}
func (c *Config) SetConfigField(key string, value any) error {
- // read the data
data, err := os.ReadFile(c.dataConfigDir)
if err != nil {
if os.IsNotExist(err) {
@@ -512,6 +513,9 @@ func (c *Config) SetConfigField(key string, value any) error {
if err != nil {
return fmt.Errorf("failed to set config field %s: %w", key, err)
}
+ if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil {
+ return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err)
+ }
if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
@@ -547,13 +551,12 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error
slog.Info("Successfully refreshed OAuth token", "provider", providerID)
providerConfig.OAuthToken = newToken
+ providerConfig.APIKey = newToken.AccessToken
switch providerID {
case string(catwalk.InferenceProviderAnthropic):
- providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken)
providerConfig.SetupClaudeCode()
case string(catwalk.InferenceProviderCopilot):
- providerConfig.APIKey = newToken.AccessToken
providerConfig.SetupGitHubCopilot()
}
@@ -592,7 +595,6 @@ func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error {
providerConfig.OAuthToken = v
switch providerID {
case string(catwalk.InferenceProviderAnthropic):
- providerConfig.APIKey = fmt.Sprintf("Bearer %s", v.AccessToken)
providerConfig.SetupClaudeCode()
case string(catwalk.InferenceProviderCopilot):
providerConfig.SetupGitHubCopilot()
@@ -6,11 +6,12 @@ import (
"log/slog"
"testing"
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/copilot"
)
-func (c *Config) importCopilot() (*oauth.Token, bool) {
+func (c *Config) ImportCopilot() (*oauth.Token, bool) {
if testing.Testing() {
return nil, false
}
@@ -31,6 +32,10 @@ func (c *Config) importCopilot() (*oauth.Token, bool) {
return nil, false
}
+ if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil {
+ return token, false
+ }
+
if err := cmp.Or(
c.SetConfigField("providers.copilot.api_key", token.AccessToken),
c.SetConfigField("providers.copilot.oauth", token),
@@ -133,8 +133,6 @@ func PushPopCrushEnv() func() {
}
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
- c.importCopilot()
-
knownProviderNames := make(map[string]bool)
restore := PushPopCrushEnv()
defer restore()
@@ -206,11 +204,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil:
prepared.SetupClaudeCode()
case p.ID == catwalk.InferenceProviderCopilot:
- if config.OAuthToken != nil {
- if token, ok := c.importCopilot(); ok {
- prepared.OAuthToken = token
- }
- }
if config.OAuthToken != nil {
prepared.SetupGitHubCopilot()
}
@@ -336,6 +329,9 @@ func (c *Config) setDefaults(workingDir, dataDir string) {
if c.Options.ContextPaths == nil {
c.Options.ContextPaths = []string{}
}
+ if c.Options.SkillsPaths == nil {
+ c.Options.SkillsPaths = []string{}
+ }
if dataDir != "" {
c.Options.DataDirectory = dataDir
} else if c.Options.DataDirectory == "" {
@@ -369,6 +365,12 @@ func (c *Config) setDefaults(workingDir, dataDir string) {
slices.Sort(c.Options.ContextPaths)
c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
+ // Add the default skills directory if not already present.
+ defaultSkillsDir := GlobalSkillsDir()
+ if !slices.Contains(c.Options.SkillsPaths, defaultSkillsDir) {
+ c.Options.SkillsPaths = append([]string{defaultSkillsDir}, c.Options.SkillsPaths...)
+ }
+
if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok {
c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str)
}
@@ -743,3 +745,25 @@ func isInsideWorktree() bool {
).CombinedOutput()
return err == nil && strings.TrimSpace(string(bts)) == "true"
}
+
+// GlobalSkillsDir returns the default directory for Agent Skills.
+// Skills in this directory are auto-discovered and their files can be read
+// without permission prompts.
+func GlobalSkillsDir() string {
+ if crushSkills := os.Getenv("CRUSH_SKILLS_DIR"); crushSkills != "" {
+ return crushSkills
+ }
+ if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" {
+ return filepath.Join(xdgConfigHome, appName, "skills")
+ }
+
+ if runtime.GOOS == "windows" {
+ localAppData := cmp.Or(
+ os.Getenv("LOCALAPPDATA"),
+ filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"),
+ )
+ return filepath.Join(localAppData, appName, "skills")
+ }
+
+ return filepath.Join(home.Dir(), ".config", appName, "skills")
+}
@@ -84,6 +84,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil {
return nil, fmt.Errorf("error preparing query UpdateSession: %w", err)
}
+ if q.updateSessionTitleAndUsageStmt, err = db.PrepareContext(ctx, updateSessionTitleAndUsage); err != nil {
+ return nil, fmt.Errorf("error preparing query UpdateSessionTitleAndUsage: %w", err)
+ }
return &q, nil
}
@@ -189,6 +192,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing updateSessionStmt: %w", cerr)
}
}
+ if q.updateSessionTitleAndUsageStmt != nil {
+ if cerr := q.updateSessionTitleAndUsageStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing updateSessionTitleAndUsageStmt: %w", cerr)
+ }
+ }
return err
}
@@ -226,53 +234,55 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
}
type Queries struct {
- db DBTX
- tx *sql.Tx
- createFileStmt *sql.Stmt
- createMessageStmt *sql.Stmt
- createSessionStmt *sql.Stmt
- deleteFileStmt *sql.Stmt
- deleteMessageStmt *sql.Stmt
- deleteSessionStmt *sql.Stmt
- deleteSessionFilesStmt *sql.Stmt
- deleteSessionMessagesStmt *sql.Stmt
- getFileStmt *sql.Stmt
- getFileByPathAndSessionStmt *sql.Stmt
- getMessageStmt *sql.Stmt
- getSessionByIDStmt *sql.Stmt
- listFilesByPathStmt *sql.Stmt
- listFilesBySessionStmt *sql.Stmt
- listLatestSessionFilesStmt *sql.Stmt
- listMessagesBySessionStmt *sql.Stmt
- listNewFilesStmt *sql.Stmt
- listSessionsStmt *sql.Stmt
- updateMessageStmt *sql.Stmt
- updateSessionStmt *sql.Stmt
+ db DBTX
+ tx *sql.Tx
+ createFileStmt *sql.Stmt
+ createMessageStmt *sql.Stmt
+ createSessionStmt *sql.Stmt
+ deleteFileStmt *sql.Stmt
+ deleteMessageStmt *sql.Stmt
+ deleteSessionStmt *sql.Stmt
+ deleteSessionFilesStmt *sql.Stmt
+ deleteSessionMessagesStmt *sql.Stmt
+ getFileStmt *sql.Stmt
+ getFileByPathAndSessionStmt *sql.Stmt
+ getMessageStmt *sql.Stmt
+ getSessionByIDStmt *sql.Stmt
+ listFilesByPathStmt *sql.Stmt
+ listFilesBySessionStmt *sql.Stmt
+ listLatestSessionFilesStmt *sql.Stmt
+ listMessagesBySessionStmt *sql.Stmt
+ listNewFilesStmt *sql.Stmt
+ listSessionsStmt *sql.Stmt
+ updateMessageStmt *sql.Stmt
+ updateSessionStmt *sql.Stmt
+ updateSessionTitleAndUsageStmt *sql.Stmt
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
- db: tx,
- tx: tx,
- createFileStmt: q.createFileStmt,
- createMessageStmt: q.createMessageStmt,
- createSessionStmt: q.createSessionStmt,
- deleteFileStmt: q.deleteFileStmt,
- deleteMessageStmt: q.deleteMessageStmt,
- deleteSessionStmt: q.deleteSessionStmt,
- deleteSessionFilesStmt: q.deleteSessionFilesStmt,
- deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
- getFileStmt: q.getFileStmt,
- getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
- getMessageStmt: q.getMessageStmt,
- getSessionByIDStmt: q.getSessionByIDStmt,
- listFilesByPathStmt: q.listFilesByPathStmt,
- listFilesBySessionStmt: q.listFilesBySessionStmt,
- listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
- listMessagesBySessionStmt: q.listMessagesBySessionStmt,
- listNewFilesStmt: q.listNewFilesStmt,
- listSessionsStmt: q.listSessionsStmt,
- updateMessageStmt: q.updateMessageStmt,
- updateSessionStmt: q.updateSessionStmt,
+ db: tx,
+ tx: tx,
+ createFileStmt: q.createFileStmt,
+ createMessageStmt: q.createMessageStmt,
+ createSessionStmt: q.createSessionStmt,
+ deleteFileStmt: q.deleteFileStmt,
+ deleteMessageStmt: q.deleteMessageStmt,
+ deleteSessionStmt: q.deleteSessionStmt,
+ deleteSessionFilesStmt: q.deleteSessionFilesStmt,
+ deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
+ getFileStmt: q.getFileStmt,
+ getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
+ getMessageStmt: q.getMessageStmt,
+ getSessionByIDStmt: q.getSessionByIDStmt,
+ listFilesByPathStmt: q.listFilesByPathStmt,
+ listFilesBySessionStmt: q.listFilesBySessionStmt,
+ listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
+ listMessagesBySessionStmt: q.listMessagesBySessionStmt,
+ listNewFilesStmt: q.listNewFilesStmt,
+ listSessionsStmt: q.listSessionsStmt,
+ updateMessageStmt: q.updateMessageStmt,
+ updateSessionStmt: q.updateSessionStmt,
+ updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt,
}
}
@@ -29,6 +29,7 @@ type Querier interface {
ListSessions(ctx context.Context) ([]Session, error)
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
+ UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error
}
var _ Querier = (*Queries)(nil)
@@ -199,3 +199,32 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
)
return i, err
}
+
+const updateSessionTitleAndUsage = `-- name: UpdateSessionTitleAndUsage :exec
+UPDATE sessions
+SET
+ title = ?,
+ prompt_tokens = prompt_tokens + ?,
+ completion_tokens = completion_tokens + ?,
+ cost = cost + ?
+WHERE id = ?
+`
+
+type UpdateSessionTitleAndUsageParams struct {
+ Title string `json:"title"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ Cost float64 `json:"cost"`
+ ID string `json:"id"`
+}
+
+func (q *Queries) UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error {
+ _, err := q.exec(ctx, q.updateSessionTitleAndUsageStmt, updateSessionTitleAndUsage,
+ arg.Title,
+ arg.PromptTokens,
+ arg.CompletionTokens,
+ arg.Cost,
+ arg.ID,
+ )
+ return err
+}
@@ -46,6 +46,15 @@ SET
WHERE id = ?
RETURNING *;
+-- name: UpdateSessionTitleAndUsage :exec
+UPDATE sessions
+SET
+ title = ?,
+ prompt_tokens = prompt_tokens + ?,
+ completion_tokens = completion_tokens + ?,
+ cost = cost + ?
+WHERE id = ?;
+
-- name: DeleteSession :exec
DELETE FROM sessions
@@ -1,8 +1,13 @@
package message
+import "strings"
+
type Attachment struct {
FilePath string
FileName string
MimeType string
Content []byte
}
+
+func (a Attachment) IsText() bool { return strings.HasPrefix(a.MimeType, "text/") }
+func (a Attachment) IsImage() bool { return strings.HasPrefix(a.MimeType, "image/") }
@@ -3,6 +3,7 @@ package message
import (
"encoding/base64"
"errors"
+ "fmt"
"slices"
"strings"
"time"
@@ -407,6 +408,15 @@ func (m *Message) SetToolResults(tr []ToolResult) {
}
}
+// Clone returns a deep copy of the message with an independent Parts slice.
+// This prevents race conditions when the message is modified concurrently.
+func (m *Message) Clone() Message {
+ clone := *m
+ clone.Parts = make([]ContentPart, len(m.Parts))
+ copy(clone.Parts, m.Parts)
+ return clone
+}
+
func (m *Message) AddFinish(reason FinishReason, message, details string) {
// remove any existing finish part
for i, part := range m.Parts {
@@ -426,16 +436,52 @@ func (m *Message) AddBinary(mimeType string, data []byte) {
m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
}
+func PromptWithTextAttachments(prompt string, attachments []Attachment) string {
+ addedAttachments := false
+ for _, content := range attachments {
+ if !content.IsText() {
+ continue
+ }
+ if !addedAttachments {
+ prompt += "\n<system_info>The files below have been attached by the user, consider them in your response</system_info>\n"
+ addedAttachments = true
+ }
+ tag := `<file>\n`
+ if content.FilePath != "" {
+ tag = fmt.Sprintf("<file path='%s'>\n", content.FilePath)
+ }
+ prompt += tag
+ prompt += "\n" + string(content.Content) + "\n</file>\n"
+ }
+ return prompt
+}
+
func (m *Message) ToAIMessage() []fantasy.Message {
var messages []fantasy.Message
switch m.Role {
case User:
var parts []fantasy.MessagePart
text := strings.TrimSpace(m.Content().Text)
+ var textAttachments []Attachment
+ for _, content := range m.BinaryContent() {
+ if !strings.HasPrefix(content.MIMEType, "text/") {
+ continue
+ }
+ textAttachments = append(textAttachments, Attachment{
+ FilePath: content.Path,
+ MimeType: content.MIMEType,
+ Content: content.Data,
+ })
+ }
+ text = PromptWithTextAttachments(text, textAttachments)
if text != "" {
parts = append(parts, fantasy.TextPart{Text: text})
}
for _, content := range m.BinaryContent() {
+ // skip text attachements
+ if strings.HasPrefix(content.MIMEType, "text/") {
+ continue
+ }
parts = append(parts, fantasy.FilePart{
Filename: content.Path,
Data: content.Data,
@@ -51,7 +51,9 @@ func (s *service) Delete(ctx context.Context, id string) error {
if err != nil {
return err
}
- s.Publish(pubsub.DeletedEvent, message)
+ // Clone the message before publishing to avoid race conditions with
+ // concurrent modifications to the Parts slice.
+ s.Publish(pubsub.DeletedEvent, message.Clone())
return nil
}
@@ -85,7 +87,9 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
if err != nil {
return Message{}, err
}
- s.Publish(pubsub.CreatedEvent, message)
+ // Clone the message before publishing to avoid race conditions with
+ // concurrent modifications to the Parts slice.
+ s.Publish(pubsub.CreatedEvent, message.Clone())
return message, nil
}
@@ -124,7 +128,9 @@ func (s *service) Update(ctx context.Context, message Message) error {
return err
}
message.UpdatedAt = time.Now().Unix()
- s.Publish(pubsub.UpdatedEvent, message)
+ // Clone the message before publishing to avoid race conditions with
+ // concurrent modifications to the Parts slice.
+ s.Publish(pubsub.UpdatedEvent, message.Clone())
return nil
}
@@ -92,22 +92,17 @@ func (b *Broker[T]) GetSubscriberCount() int {
func (b *Broker[T]) Publish(t EventType, payload T) {
b.mu.RLock()
+ defer b.mu.RUnlock()
+
select {
case <-b.done:
- b.mu.RUnlock()
return
default:
}
- subscribers := make([]chan Event[T], 0, len(b.subs))
- for sub := range b.subs {
- subscribers = append(subscribers, sub)
- }
- b.mu.RUnlock()
-
event := Event[T]{Type: t, Payload: payload}
- for _, sub := range subscribers {
+ for sub := range b.subs {
select {
case sub <- event:
default:
@@ -50,6 +50,7 @@ type Service interface {
Get(ctx context.Context, id string) (Session, error)
List(ctx context.Context) ([]Session, error)
Save(ctx context.Context, session Session) (Session, error)
+ UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
Delete(ctx context.Context, id string) error
// Agent tool session management
@@ -156,6 +157,18 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
return session, nil
}
+// UpdateTitleAndUsage updates only the title and usage fields atomically.
+// This is safer than fetching, modifying, and saving the entire session.
+func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
+ return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
+ ID: sessionID,
+ Title: title,
+ PromptTokens: promptTokens,
+ CompletionTokens: completionTokens,
+ Cost: cost,
+ })
+}
+
func (s *service) List(ctx context.Context) ([]Session, error) {
dbSessions, err := s.q.ListSessions(ctx)
if err != nil {
@@ -0,0 +1,164 @@
+// Package skills implements the Agent Skills open standard.
+// See https://agentskills.io for the specification.
+package skills
+
+import (
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+
+ "gopkg.in/yaml.v3"
+)
+
+const (
+ SkillFileName = "SKILL.md"
+ MaxNameLength = 64
+ MaxDescriptionLength = 1024
+ MaxCompatibilityLength = 500
+)
+
+var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
+
+// Skill represents a parsed SKILL.md file.
+type Skill struct {
+ Name string `yaml:"name" json:"name"`
+ Description string `yaml:"description" json:"description"`
+ License string `yaml:"license,omitempty" json:"license,omitempty"`
+ Compatibility string `yaml:"compatibility,omitempty" json:"compatibility,omitempty"`
+ Metadata map[string]string `yaml:"metadata,omitempty" json:"metadata,omitempty"`
+ Instructions string `yaml:"-" json:"instructions"`
+ Path string `yaml:"-" json:"path"`
+ SkillFilePath string `yaml:"-" json:"skill_file_path"`
+}
+
+// Validate checks if the skill meets spec requirements.
+func (s *Skill) Validate() error {
+ var errs []error
+
+ if s.Name == "" {
+ errs = append(errs, errors.New("name is required"))
+ } else {
+ if len(s.Name) > MaxNameLength {
+ errs = append(errs, fmt.Errorf("name exceeds %d characters", MaxNameLength))
+ }
+ if !namePattern.MatchString(s.Name) {
+ errs = append(errs, errors.New("name must be alphanumeric with hyphens, no leading/trailing/consecutive hyphens"))
+ }
+ if s.Path != "" && !strings.EqualFold(filepath.Base(s.Path), s.Name) {
+ errs = append(errs, fmt.Errorf("name %q must match directory %q", s.Name, filepath.Base(s.Path)))
+ }
+ }
+
+ if s.Description == "" {
+ errs = append(errs, errors.New("description is required"))
+ } else if len(s.Description) > MaxDescriptionLength {
+ errs = append(errs, fmt.Errorf("description exceeds %d characters", MaxDescriptionLength))
+ }
+
+ if len(s.Compatibility) > MaxCompatibilityLength {
+ errs = append(errs, fmt.Errorf("compatibility exceeds %d characters", MaxCompatibilityLength))
+ }
+
+ return errors.Join(errs...)
+}
+
+// Parse parses a SKILL.md file.
+func Parse(path string) (*Skill, error) {
+ content, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ frontmatter, body, err := splitFrontmatter(string(content))
+ if err != nil {
+ return nil, err
+ }
+
+ var skill Skill
+ if err := yaml.Unmarshal([]byte(frontmatter), &skill); err != nil {
+ return nil, fmt.Errorf("parsing frontmatter: %w", err)
+ }
+
+ skill.Instructions = strings.TrimSpace(body)
+ skill.Path = filepath.Dir(path)
+ skill.SkillFilePath = path
+
+ return &skill, nil
+}
+
+// splitFrontmatter extracts YAML frontmatter and body from markdown content.
+func splitFrontmatter(content string) (frontmatter, body string, err error) {
+ // Normalize line endings to \n for consistent parsing.
+ content = strings.ReplaceAll(content, "\r\n", "\n")
+ if !strings.HasPrefix(content, "---\n") {
+ return "", "", errors.New("no YAML frontmatter found")
+ }
+
+ rest := strings.TrimPrefix(content, "---\n")
+ before, after, ok := strings.Cut(rest, "\n---")
+ if !ok {
+ return "", "", errors.New("unclosed frontmatter")
+ }
+
+ return before, after, nil
+}
+
+// Discover finds all valid skills in the given paths.
+func Discover(paths []string) []*Skill {
+ var skills []*Skill
+ seen := make(map[string]bool)
+
+ for _, base := range paths {
+ filepath.WalkDir(base, func(path string, d os.DirEntry, err error) error {
+ if err != nil {
+ return nil
+ }
+ if d.IsDir() || d.Name() != SkillFileName || seen[path] {
+ return nil
+ }
+ seen[path] = true
+ skill, err := Parse(path)
+ if err != nil {
+ slog.Warn("Failed to parse skill file", "path", path, "error", err)
+ return nil
+ }
+ if err := skill.Validate(); err != nil {
+ slog.Warn("Skill validation failed", "path", path, "error", err)
+ return nil
+ }
+ slog.Info("Successfully loaded skill", "name", skill.Name, "path", path)
+ skills = append(skills, skill)
+ return nil
+ })
+ }
+
+ return skills
+}
+
+// ToPromptXML generates XML for injection into the system prompt.
+func ToPromptXML(skills []*Skill) string {
+ if len(skills) == 0 {
+ return ""
+ }
+
+ var sb strings.Builder
+ sb.WriteString("<available_skills>\n")
+ for _, s := range skills {
+ sb.WriteString(" <skill>\n")
+ fmt.Fprintf(&sb, " <name>%s</name>\n", escape(s.Name))
+ fmt.Fprintf(&sb, " <description>%s</description>\n", escape(s.Description))
+ fmt.Fprintf(&sb, " <location>%s</location>\n", escape(s.SkillFilePath))
+ sb.WriteString(" </skill>\n")
+ }
+ sb.WriteString("</available_skills>")
+ return sb.String()
+}
+
+func escape(s string) string {
+ r := strings.NewReplacer("&", "&", "<", "<", ">", ">", "\"", """, "'", "'")
+ return r.Replace(s)
+}
@@ -0,0 +1,249 @@
+package skills
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParse(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ content string
+ wantName string
+ wantDesc string
+ wantLicense string
+ wantCompat string
+ wantMeta map[string]string
+ wantTools string
+ wantInstr string
+ wantErr bool
+ }{
+ {
+ name: "full skill",
+ content: `---
+name: pdf-processing
+description: Extracts text and tables from PDF files, fills PDF forms, and merges multiple PDFs.
+license: Apache-2.0
+compatibility: Requires python 3.8+, pdfplumber, pdfrw libraries
+metadata:
+ author: example-org
+ version: "1.0"
+---
+
+# PDF Processing
+
+## When to use this skill
+Use this skill when the user needs to work with PDF files.
+`,
+ wantName: "pdf-processing",
+ wantDesc: "Extracts text and tables from PDF files, fills PDF forms, and merges multiple PDFs.",
+ wantLicense: "Apache-2.0",
+ wantCompat: "Requires python 3.8+, pdfplumber, pdfrw libraries",
+ wantMeta: map[string]string{"author": "example-org", "version": "1.0"},
+ wantInstr: "# PDF Processing\n\n## When to use this skill\nUse this skill when the user needs to work with PDF files.",
+ },
+ {
+ name: "minimal skill",
+ content: `---
+name: my-skill
+description: A simple skill for testing.
+---
+
+# My Skill
+
+Instructions here.
+`,
+ wantName: "my-skill",
+ wantDesc: "A simple skill for testing.",
+ wantInstr: "# My Skill\n\nInstructions here.",
+ },
+ {
+ name: "no frontmatter",
+ content: "# Just Markdown\n\nNo frontmatter here.",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Write content to temp file.
+ dir := t.TempDir()
+ path := filepath.Join(dir, "SKILL.md")
+ require.NoError(t, os.WriteFile(path, []byte(tt.content), 0o644))
+
+ skill, err := Parse(path)
+ if tt.wantErr {
+ require.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+
+ require.Equal(t, tt.wantName, skill.Name)
+ require.Equal(t, tt.wantDesc, skill.Description)
+ require.Equal(t, tt.wantLicense, skill.License)
+ require.Equal(t, tt.wantCompat, skill.Compatibility)
+
+ if tt.wantMeta != nil {
+ require.Equal(t, tt.wantMeta, skill.Metadata)
+ }
+
+ require.Equal(t, tt.wantInstr, skill.Instructions)
+ })
+ }
+}
+
+func TestSkillValidate(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ skill Skill
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "valid skill",
+ skill: Skill{
+ Name: "pdf-processing",
+ Description: "Processes PDF files.",
+ Path: "/skills/pdf-processing",
+ },
+ },
+ {
+ name: "missing name",
+ skill: Skill{Description: "Some description."},
+ wantErr: true,
+ errMsg: "name is required",
+ },
+ {
+ name: "missing description",
+ skill: Skill{Name: "my-skill", Path: "/skills/my-skill"},
+ wantErr: true,
+ errMsg: "description is required",
+ },
+ {
+ name: "name too long",
+ skill: Skill{Name: strings.Repeat("a", 65), Description: "Some description."},
+ wantErr: true,
+ errMsg: "exceeds",
+ },
+ {
+ name: "valid name - mixed case",
+ skill: Skill{Name: "MySkill", Description: "Some description.", Path: "/skills/MySkill"},
+ wantErr: false,
+ },
+ {
+ name: "invalid name - starts with hyphen",
+ skill: Skill{Name: "-my-skill", Description: "Some description."},
+ wantErr: true,
+ errMsg: "alphanumeric with hyphens",
+ },
+ {
+ name: "name doesn't match directory",
+ skill: Skill{Name: "my-skill", Description: "Some description.", Path: "/skills/other-skill"},
+ wantErr: true,
+ errMsg: "must match directory",
+ },
+ {
+ name: "description too long",
+ skill: Skill{Name: "my-skill", Description: strings.Repeat("a", 1025), Path: "/skills/my-skill"},
+ wantErr: true,
+ errMsg: "description exceeds",
+ },
+ {
+ name: "compatibility too long",
+ skill: Skill{Name: "my-skill", Description: "desc", Compatibility: strings.Repeat("a", 501), Path: "/skills/my-skill"},
+ wantErr: true,
+ errMsg: "compatibility exceeds",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := tt.skill.Validate()
+ if tt.wantErr {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errMsg)
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestDiscover(t *testing.T) {
+ t.Parallel()
+
+ tmpDir := t.TempDir()
+
+ // Create valid skill 1.
+ skill1Dir := filepath.Join(tmpDir, "skill-one")
+ require.NoError(t, os.MkdirAll(skill1Dir, 0o755))
+ require.NoError(t, os.WriteFile(filepath.Join(skill1Dir, "SKILL.md"), []byte(`---
+name: skill-one
+description: First test skill.
+---
+# Skill One
+`), 0o644))
+
+ // Create valid skill 2 in nested directory.
+ skill2Dir := filepath.Join(tmpDir, "nested", "skill-two")
+ require.NoError(t, os.MkdirAll(skill2Dir, 0o755))
+ require.NoError(t, os.WriteFile(filepath.Join(skill2Dir, "SKILL.md"), []byte(`---
+name: skill-two
+description: Second test skill.
+---
+# Skill Two
+`), 0o644))
+
+ // Create invalid skill (won't be included).
+ invalidDir := filepath.Join(tmpDir, "invalid-dir")
+ require.NoError(t, os.MkdirAll(invalidDir, 0o755))
+ require.NoError(t, os.WriteFile(filepath.Join(invalidDir, "SKILL.md"), []byte(`---
+name: wrong-name
+description: Name doesn't match directory.
+---
+`), 0o644))
+
+ skills := Discover([]string{tmpDir})
+ require.Len(t, skills, 2)
+
+ names := make(map[string]bool)
+ for _, s := range skills {
+ names[s.Name] = true
+ }
+ require.True(t, names["skill-one"])
+ require.True(t, names["skill-two"])
+}
+
+func TestToPromptXML(t *testing.T) {
+ t.Parallel()
+
+ skills := []*Skill{
+ {Name: "pdf-processing", Description: "Extracts text from PDFs.", SkillFilePath: "/skills/pdf-processing/SKILL.md"},
+ {Name: "data-analysis", Description: "Analyzes datasets & charts.", SkillFilePath: "/skills/data-analysis/SKILL.md"},
+ }
+
+ xml := ToPromptXML(skills)
+
+ require.Contains(t, xml, "<available_skills>")
+ require.Contains(t, xml, "<name>pdf-processing</name>")
+ require.Contains(t, xml, "<description>Extracts text from PDFs.</description>")
+ require.Contains(t, xml, "&") // XML escaping
+}
+
+func TestToPromptXMLEmpty(t *testing.T) {
+ t.Parallel()
+ require.Empty(t, ToPromptXML(nil))
+ require.Empty(t, ToPromptXML([]*Skill{}))
+}
@@ -2,6 +2,7 @@ package editor
import (
"context"
+ "errors"
"fmt"
"math/rand"
"net/http"
@@ -29,6 +30,7 @@ import (
"github.com/charmbracelet/crush/internal/tui/components/dialogs/quit"
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/crush/internal/tui/util"
+ "github.com/charmbracelet/x/ansi"
)
type Editor interface {
@@ -84,10 +86,7 @@ var DeleteKeyMaps = DeleteAttachmentKeyMaps{
),
}
-const (
- maxAttachments = 5
- maxFileResults = 25
-)
+const maxFileResults = 25
type OpenEditorMsg struct {
Text string
@@ -145,14 +144,14 @@ func (m *editorCmp) send() tea.Cmd {
return util.CmdHandler(dialogs.OpenDialogMsg{Model: quit.NewQuitDialog()})
}
- m.textarea.Reset()
attachments := m.attachments
- m.attachments = nil
if value == "" {
return nil
}
+ m.textarea.Reset()
+ m.attachments = nil
// Change the placeholder when sending a new message.
m.randomizePlaceholders()
@@ -176,9 +175,6 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
case tea.WindowSizeMsg:
return m, m.repositionCompletions
case filepicker.FilePickedMsg:
- if len(m.attachments) >= maxAttachments {
- return m, util.ReportError(fmt.Errorf("cannot add more than %d images", maxAttachments))
- }
m.attachments = append(m.attachments, msg.Attachment)
return m, nil
case completions.CompletionsOpenedMsg:
@@ -206,6 +202,17 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.currentQuery = ""
m.completionsStartIndex = 0
}
+ content, err := os.ReadFile(item.Path)
+ if err != nil {
+ // if it fails, let the LLM handle it later.
+ return m, nil
+ }
+ m.attachments = append(m.attachments, message.Attachment{
+ FilePath: item.Path,
+ FileName: filepath.Base(item.Path),
+ MimeType: mimeOf(content),
+ Content: content,
+ })
}
case commands.OpenExternalEditorMsg:
@@ -217,39 +224,29 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.textarea.SetValue(msg.Text)
m.textarea.MoveToEnd()
case tea.PasteMsg:
- path := strings.ReplaceAll(msg.Content, "\\ ", " ")
- // try to get an image
- path, err := filepath.Abs(strings.TrimSpace(path))
- if err != nil {
+ content, path, err := pasteToFile(msg)
+ if errors.Is(err, errNotAFile) {
m.textarea, cmd = m.textarea.Update(msg)
return m, cmd
}
- isAllowedType := false
- for _, ext := range filepicker.AllowedTypes {
- if strings.HasSuffix(path, ext) {
- isAllowedType = true
- break
- }
- }
- if !isAllowedType {
- m.textarea, cmd = m.textarea.Update(msg)
- return m, cmd
+ if err != nil {
+ return m, util.ReportError(err)
}
- tooBig, _ := filepicker.IsFileTooBig(path, filepicker.MaxAttachmentSize)
- if tooBig {
- m.textarea, cmd = m.textarea.Update(msg)
- return m, cmd
+
+ if len(content) > maxAttachmentSize {
+ return m, util.ReportWarn("File is too big (>5mb)")
}
- content, err := os.ReadFile(path)
- if err != nil {
- m.textarea, cmd = m.textarea.Update(msg)
- return m, cmd
+ mimeType := mimeOf(content)
+ attachment := message.Attachment{
+ FilePath: path,
+ FileName: filepath.Base(path),
+ MimeType: mimeType,
+ Content: content,
+ }
+ if !attachment.IsText() && !attachment.IsImage() {
+ return m, util.ReportWarn("Invalid file content type: " + mimeType)
}
- mimeBufferSize := min(512, len(content))
- mimeType := http.DetectContentType(content[:mimeBufferSize])
- fileName := filepath.Base(path)
- attachment := message.Attachment{FilePath: path, FileName: fileName, MimeType: mimeType, Content: content}
return m, util.CmdHandler(filepicker.FilePickedMsg{
Attachment: attachment,
})
@@ -427,18 +424,17 @@ func (m *editorCmp) View() string {
m.textarea.Placeholder = "Yolo mode!"
}
if len(m.attachments) == 0 {
- content := t.S().Base.Padding(1).Render(
+ return t.S().Base.Padding(1).Render(
m.textarea.View(),
)
- return content
}
- content := t.S().Base.Padding(0, 1, 1, 1).Render(
- lipgloss.JoinVertical(lipgloss.Top,
+ return t.S().Base.Padding(0, 1, 1, 1).Render(
+ lipgloss.JoinVertical(
+ lipgloss.Top,
m.attachmentsContent(),
m.textarea.View(),
),
)
- return content
}
func (m *editorCmp) SetSize(width, height int) tea.Cmd {
@@ -456,24 +452,45 @@ func (m *editorCmp) GetSize() (int, int) {
func (m *editorCmp) attachmentsContent() string {
var styledAttachments []string
t := styles.CurrentTheme()
- attachmentStyles := t.S().Base.
- MarginLeft(1).
+ attachmentStyle := t.S().Base.
+ Padding(0, 1).
+ MarginRight(1).
Background(t.FgMuted).
- Foreground(t.FgBase)
+ Foreground(t.FgBase).
+ Render
+ iconStyle := t.S().Base.
+ Foreground(t.BgSubtle).
+ Background(t.Green).
+ Padding(0, 1).
+ Bold(true).
+ Render
+ rmStyle := t.S().Base.
+ Padding(0, 1).
+ Bold(true).
+ Background(t.Red).
+ Foreground(t.FgBase).
+ Render
for i, attachment := range m.attachments {
- var filename string
- if len(attachment.FileName) > 10 {
- filename = fmt.Sprintf(" %s %s...", styles.DocumentIcon, attachment.FileName[0:7])
- } else {
- filename = fmt.Sprintf(" %s %s", styles.DocumentIcon, attachment.FileName)
+ filename := ansi.Truncate(filepath.Base(attachment.FileName), 10, "...")
+ icon := styles.ImageIcon
+ if attachment.IsText() {
+ icon = styles.TextIcon
}
if m.deleteMode {
- filename = fmt.Sprintf("%d%s", i, filename)
+ styledAttachments = append(
+ styledAttachments,
+ rmStyle(fmt.Sprintf("%d", i)),
+ attachmentStyle(filename),
+ )
+ continue
}
- styledAttachments = append(styledAttachments, attachmentStyles.Render(filename))
+ styledAttachments = append(
+ styledAttachments,
+ iconStyle(icon),
+ attachmentStyle(filename),
+ )
}
- content := lipgloss.JoinHorizontal(lipgloss.Left, styledAttachments...)
- return content
+ return lipgloss.JoinHorizontal(lipgloss.Left, styledAttachments...)
}
func (m *editorCmp) SetPosition(x, y int) tea.Cmd {
@@ -597,3 +614,51 @@ func New(app *app.App) Editor {
return e
}
+
+var maxAttachmentSize = 5 * 1024 * 1024 // 5MB
+
+var errNotAFile = errors.New("not a file")
+
+func pasteToFile(msg tea.PasteMsg) ([]byte, string, error) {
+ content, path, err := filepathToFile(msg.Content)
+ if err == nil {
+ return content, path, err
+ }
+
+ if strings.Count(msg.Content, "\n") > 2 {
+ return contentToFile([]byte(msg.Content))
+ }
+
+ return nil, "", errNotAFile
+}
+
+func contentToFile(content []byte) ([]byte, string, error) {
+ f, err := os.CreateTemp("", "paste_*.txt")
+ if err != nil {
+ return nil, "", err
+ }
+ if _, err := f.Write(content); err != nil {
+ return nil, "", err
+ }
+ if err := f.Close(); err != nil {
+ return nil, "", err
+ }
+ return content, f.Name(), nil
+}
+
+func filepathToFile(name string) ([]byte, string, error) {
+ path, err := filepath.Abs(strings.TrimSpace(strings.ReplaceAll(name, "\\", "")))
+ if err != nil {
+ return nil, "", err
+ }
+ content, err := os.ReadFile(path)
+ if err != nil {
+ return nil, "", err
+ }
+ return content, path, nil
+}
+
+func mimeOf(content []byte) string {
+ mimeBufferSize := min(512, len(content))
+ return http.DetectContentType(content[:mimeBufferSize])
+}
@@ -227,19 +227,32 @@ func (m *messageCmp) renderUserMessage() string {
m.toMarkdown(m.message.Content().String()),
}
- attachmentStyles := t.S().Text.
- MarginLeft(1).
- Background(t.BgSubtle)
+ attachmentStyle := t.S().Base.
+ Padding(0, 1).
+ MarginRight(1).
+ Background(t.FgMuted).
+ Foreground(t.FgBase).
+ Render
+ iconStyle := t.S().Base.
+ Foreground(t.BgSubtle).
+ Background(t.Green).
+ Padding(0, 1).
+ Bold(true).
+ Render
attachments := make([]string, len(m.message.BinaryContent()))
for i, attachment := range m.message.BinaryContent() {
const maxFilenameWidth = 10
- filename := filepath.Base(attachment.Path)
- attachments[i] = attachmentStyles.Render(fmt.Sprintf(
- " %s %s ",
- styles.DocumentIcon,
- ansi.Truncate(filename, maxFilenameWidth, "..."),
- ))
+ filename := ansi.Truncate(filepath.Base(attachment.Path), 10, "...")
+ icon := styles.ImageIcon
+ if strings.HasPrefix(attachment.MIMEType, "text/") {
+ icon = styles.TextIcon
+ }
+ attachments[i] = lipgloss.JoinHorizontal(
+ lipgloss.Left,
+ iconStyle(icon),
+ attachmentStyle(filename),
+ )
}
if len(attachments) > 0 {
@@ -12,12 +12,15 @@ import (
"github.com/atotto/clipboard"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/agent"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/home"
"github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/core/layout"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
"github.com/charmbracelet/crush/internal/tui/components/logo"
lspcomponent "github.com/charmbracelet/crush/internal/tui/components/lsp"
@@ -55,6 +58,12 @@ type Splash interface {
// IsClaudeOAuthComplete returns whether Claude OAuth flow is complete
IsClaudeOAuthComplete() bool
+
+ // IsShowingClaudeOAuth2 returns whether showing Hyper OAuth2 flow
+ IsShowingHyperOAuth2() bool
+
+ // IsShowingClaudeOAuth2 returns whether showing GitHub Copilot OAuth2 flow
+ IsShowingCopilotOAuth2() bool
}
const (
@@ -87,6 +96,14 @@ type splashCmp struct {
isAPIKeyValid bool
apiKeyValue string
+ // Hyper device flow state
+ hyperDeviceFlow *hyper.DeviceFlow
+ showHyperDeviceFlow bool
+
+ // Copilot device flow state
+ copilotDeviceFlow *copilot.DeviceFlow
+ showCopilotDeviceFlow bool
+
// Claude state
claudeAuthMethodChooser *claude.AuthMethodChooser
claudeOAuth2 *claude.OAuth2
@@ -186,6 +203,26 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
}
return s, tea.Batch(cmds...)
+ case hyper.DeviceFlowCompletedMsg:
+ s.showHyperDeviceFlow = false
+ return s, s.saveAPIKeyAndContinue(msg.Token, true)
+ case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
+ if s.hyperDeviceFlow != nil {
+ u, cmd := s.hyperDeviceFlow.Update(msg)
+ s.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ return s, cmd
+ }
+ return s, nil
+ case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
+ if s.copilotDeviceFlow != nil {
+ u, cmd := s.copilotDeviceFlow.Update(msg)
+ s.copilotDeviceFlow = u.(*copilot.DeviceFlow)
+ return s, cmd
+ }
+ return s, nil
+ case copilot.DeviceFlowCompletedMsg:
+ s.showCopilotDeviceFlow = false
+ return s, s.saveAPIKeyAndContinue(msg.Token, true)
case claude.AuthenticationCompleteMsg:
s.showClaudeAuthMethodChooser = false
s.showClaudeOAuth2 = false
@@ -205,41 +242,49 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
}
case tea.KeyPressMsg:
switch {
- case key.Matches(msg, s.keyMap.Copy):
- if s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateURL {
- return s, tea.Sequence(
- tea.SetClipboard(s.claudeOAuth2.URL),
- func() tea.Msg {
- _ = clipboard.WriteAll(s.claudeOAuth2.URL)
- return nil
- },
- util.ReportInfo("URL copied to clipboard"),
- )
- } else if s.showClaudeAuthMethodChooser {
- u, cmd := s.claudeAuthMethodChooser.Update(msg)
- s.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
- return s, cmd
- } else if s.showClaudeOAuth2 {
- u, cmd := s.claudeOAuth2.Update(msg)
- s.claudeOAuth2 = u.(*claude.OAuth2)
- return s, cmd
- }
+ case key.Matches(msg, s.keyMap.Copy) && s.showHyperDeviceFlow:
+ return s, s.hyperDeviceFlow.CopyCode()
+ case key.Matches(msg, s.keyMap.Copy) && s.showCopilotDeviceFlow:
+ return s, s.copilotDeviceFlow.CopyCode()
+ case key.Matches(msg, s.keyMap.Copy) && s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateURL:
+ return s, tea.Sequence(
+ tea.SetClipboard(s.claudeOAuth2.URL),
+ func() tea.Msg {
+ _ = clipboard.WriteAll(s.claudeOAuth2.URL)
+ return nil
+ },
+ util.ReportInfo("URL copied to clipboard"),
+ )
+ case key.Matches(msg, s.keyMap.Copy) && s.showClaudeAuthMethodChooser:
+ u, cmd := s.claudeAuthMethodChooser.Update(msg)
+ s.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
+ return s, cmd
+ case key.Matches(msg, s.keyMap.Copy) && s.showClaudeOAuth2:
+ u, cmd := s.claudeOAuth2.Update(msg)
+ s.claudeOAuth2 = u.(*claude.OAuth2)
+ return s, cmd
case key.Matches(msg, s.keyMap.Back):
- if s.showClaudeAuthMethodChooser {
+ switch {
+ case s.showClaudeAuthMethodChooser:
s.claudeAuthMethodChooser.SetDefaults()
s.showClaudeAuthMethodChooser = false
return s, nil
- }
- if s.showClaudeOAuth2 {
+ case s.showClaudeOAuth2:
s.claudeOAuth2.SetDefaults()
s.showClaudeOAuth2 = false
s.showClaudeAuthMethodChooser = true
return s, nil
- }
- if s.isAPIKeyValid {
+ case s.showHyperDeviceFlow:
+ s.hyperDeviceFlow = nil
+ s.showHyperDeviceFlow = false
return s, nil
- }
- if s.needsAPIKey {
+ case s.showCopilotDeviceFlow:
+ s.copilotDeviceFlow = nil
+ s.showCopilotDeviceFlow = false
+ return s, nil
+ case s.isAPIKeyValid:
+ return s, nil
+ case s.needsAPIKey:
if s.selectedModel.Provider.ID == catwalk.InferenceProviderAnthropic {
s.showClaudeAuthMethodChooser = true
}
@@ -251,7 +296,8 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return s, nil
}
case key.Matches(msg, s.keyMap.Select):
- if s.showClaudeAuthMethodChooser {
+ switch {
+ case s.showClaudeAuthMethodChooser:
selectedItem := s.modelList.SelectedModel()
if selectedItem == nil {
return s, nil
@@ -269,16 +315,17 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
s.showClaudeOAuth2 = true
}
return s, nil
- }
- if s.showClaudeOAuth2 {
+ case s.showClaudeOAuth2:
m2, cmd2 := s.claudeOAuth2.ValidationConfirm()
s.claudeOAuth2 = m2.(*claude.OAuth2)
return s, cmd2
- }
- if s.isAPIKeyValid {
+ case s.showHyperDeviceFlow:
+ return s, s.hyperDeviceFlow.CopyCodeAndOpenURL()
+ case s.showCopilotDeviceFlow:
+ return s, s.copilotDeviceFlow.CopyCodeAndOpenURL()
+ case s.isAPIKeyValid:
return s, s.saveAPIKeyAndContinue(s.apiKeyValue, true)
- }
- if s.isOnboarding && !s.needsAPIKey {
+ case s.isOnboarding && !s.needsAPIKey:
selectedItem := s.modelList.SelectedModel()
if selectedItem == nil {
return s, nil
@@ -288,9 +335,26 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
s.isOnboarding = false
return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
} else {
- if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
+ switch selectedItem.Provider.ID {
+ case catwalk.InferenceProviderAnthropic:
s.showClaudeAuthMethodChooser = true
return s, nil
+ case hyperp.Name:
+ s.selectedModel = selectedItem
+ s.showHyperDeviceFlow = true
+ s.hyperDeviceFlow = hyper.NewDeviceFlow()
+ s.hyperDeviceFlow.SetWidth(min(s.width-2, 60))
+ return s, s.hyperDeviceFlow.Init()
+ case catwalk.InferenceProviderCopilot:
+ if token, ok := config.Get().ImportCopilot(); ok {
+ s.selectedModel = selectedItem
+ return s, s.saveAPIKeyAndContinue(token, true)
+ }
+ s.selectedModel = selectedItem
+ s.showCopilotDeviceFlow = true
+ s.copilotDeviceFlow = copilot.NewDeviceFlow()
+ s.copilotDeviceFlow.SetWidth(min(s.width-2, 60))
+ return s, s.copilotDeviceFlow.Init()
}
// Provider not configured, show API key input
s.needsAPIKey = true
@@ -298,7 +362,7 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
return s, nil
}
- } else if s.needsAPIKey {
+ case s.needsAPIKey:
// Handle API key submission
s.apiKeyValue = strings.TrimSpace(s.apiKeyInput.Value())
if s.apiKeyValue == "" {
@@ -339,7 +403,7 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
}
},
)
- } else if s.needsProjectInit {
+ case s.needsProjectInit:
return s, s.initializeProject()
}
case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
@@ -387,44 +451,71 @@ func (s *splashCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return s, s.initializeProject()
}
default:
- if s.showClaudeAuthMethodChooser {
+ switch {
+ case s.showClaudeAuthMethodChooser:
u, cmd := s.claudeAuthMethodChooser.Update(msg)
s.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
return s, cmd
- } else if s.showClaudeOAuth2 {
+ case s.showClaudeOAuth2:
u, cmd := s.claudeOAuth2.Update(msg)
s.claudeOAuth2 = u.(*claude.OAuth2)
return s, cmd
- } else if s.needsAPIKey {
+ case s.showHyperDeviceFlow:
+ u, cmd := s.hyperDeviceFlow.Update(msg)
+ s.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ return s, cmd
+ case s.showCopilotDeviceFlow:
+ u, cmd := s.copilotDeviceFlow.Update(msg)
+ s.copilotDeviceFlow = u.(*copilot.DeviceFlow)
+ return s, cmd
+ case s.needsAPIKey:
u, cmd := s.apiKeyInput.Update(msg)
s.apiKeyInput = u.(*models.APIKeyInput)
return s, cmd
- } else if s.isOnboarding {
+ case s.isOnboarding:
u, cmd := s.modelList.Update(msg)
s.modelList = u
return s, cmd
}
}
case tea.PasteMsg:
- if s.showClaudeOAuth2 {
+ switch {
+ case s.showClaudeOAuth2:
u, cmd := s.claudeOAuth2.Update(msg)
s.claudeOAuth2 = u.(*claude.OAuth2)
return s, cmd
- } else if s.needsAPIKey {
+ case s.showHyperDeviceFlow:
+ u, cmd := s.hyperDeviceFlow.Update(msg)
+ s.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ return s, cmd
+ case s.showCopilotDeviceFlow:
+ u, cmd := s.copilotDeviceFlow.Update(msg)
+ s.copilotDeviceFlow = u.(*copilot.DeviceFlow)
+ return s, cmd
+ case s.needsAPIKey:
u, cmd := s.apiKeyInput.Update(msg)
s.apiKeyInput = u.(*models.APIKeyInput)
return s, cmd
- } else if s.isOnboarding {
+ case s.isOnboarding:
var cmd tea.Cmd
s.modelList, cmd = s.modelList.Update(msg)
return s, cmd
}
case spinner.TickMsg:
- if s.showClaudeOAuth2 {
+ switch {
+ case s.showClaudeOAuth2:
u, cmd := s.claudeOAuth2.Update(msg)
s.claudeOAuth2 = u.(*claude.OAuth2)
return s, cmd
- } else {
+ case s.showHyperDeviceFlow:
+ u, cmd := s.hyperDeviceFlow.Update(msg)
+ s.hyperDeviceFlow = u.(*hyper.DeviceFlow)
+ return s, cmd
+ case s.showCopilotDeviceFlow:
+ u, cmd := s.copilotDeviceFlow.Update(msg)
+ s.copilotDeviceFlow = u.(*copilot.DeviceFlow)
+ return s, cmd
+ default:
u, cmd := s.apiKeyInput.Update(msg)
s.apiKeyInput = u.(*models.APIKeyInput)
return s, cmd
@@ -562,8 +653,10 @@ func (s *splashCmp) isProviderConfigured(providerID string) bool {
func (s *splashCmp) View() string {
t := styles.CurrentTheme()
var content string
- if s.showClaudeAuthMethodChooser {
- remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
+
+ switch {
+ case s.showClaudeAuthMethodChooser:
+ remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY
chooserView := s.claudeAuthMethodChooser.View()
authMethodSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
lipgloss.JoinVertical(
@@ -578,8 +671,8 @@ func (s *splashCmp) View() string {
s.logoRendered,
authMethodSelector,
)
- } else if s.showClaudeOAuth2 {
- remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
+ case s.showClaudeOAuth2:
+ remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY
oauth2View := s.claudeOAuth2.View()
oauthSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
lipgloss.JoinVertical(
@@ -594,8 +687,38 @@ func (s *splashCmp) View() string {
s.logoRendered,
oauthSelector,
)
- } else if s.needsAPIKey {
- remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
+ case s.showHyperDeviceFlow:
+ remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY
+ hyperView := s.hyperDeviceFlow.View()
+ hyperSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
+ lipgloss.JoinVertical(
+ lipgloss.Left,
+ t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Let's Auth Hyper"),
+ hyperView,
+ ),
+ )
+ content = lipgloss.JoinVertical(
+ lipgloss.Left,
+ s.logoRendered,
+ hyperSelector,
+ )
+ case s.showCopilotDeviceFlow:
+ remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY
+ copilotView := s.copilotDeviceFlow.View()
+ copilotSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
+ lipgloss.JoinVertical(
+ lipgloss.Left,
+ t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Let's Auth GitHub Copilot"),
+ copilotView,
+ ),
+ )
+ content = lipgloss.JoinVertical(
+ lipgloss.Left,
+ s.logoRendered,
+ copilotSelector,
+ )
+ case s.needsAPIKey:
+ remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY
apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
lipgloss.JoinVertical(
@@ -608,9 +731,9 @@ func (s *splashCmp) View() string {
s.logoRendered,
apiKeySelector,
)
- } else if s.isOnboarding {
+ case s.isOnboarding:
modelListView := s.modelList.View()
- remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
+ remainingHeight := s.height - lipgloss.Height(s.logoRendered) - SplashScreenPaddingY
modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
lipgloss.JoinVertical(
lipgloss.Left,
@@ -624,7 +747,7 @@ func (s *splashCmp) View() string {
s.logoRendered,
modelSelector,
)
- } else if s.needsProjectInit {
+ case s.needsProjectInit:
titleStyle := t.S().Base.Foreground(t.FgBase)
pathStyle := t.S().Base.Foreground(t.Success).PaddingLeft(2)
bodyStyle := t.S().Base.Foreground(t.FgMuted)
@@ -675,7 +798,7 @@ func (s *splashCmp) View() string {
"",
initContent,
)
- } else {
+ default:
parts := []string{
s.logoRendered,
s.infoSection(),
@@ -692,28 +815,25 @@ func (s *splashCmp) View() string {
}
func (s *splashCmp) Cursor() *tea.Cursor {
- if s.showClaudeAuthMethodChooser {
+ switch {
+ case s.showClaudeAuthMethodChooser:
return nil
- }
- if s.showClaudeOAuth2 {
+ case s.showClaudeOAuth2:
if cursor := s.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
return s.moveCursor(cursor)
}
return nil
- }
- if s.needsAPIKey {
+ case s.needsAPIKey:
cursor := s.apiKeyInput.Cursor()
if cursor != nil {
return s.moveCursor(cursor)
}
- } else if s.isOnboarding {
+ case s.isOnboarding:
cursor := s.modelList.Cursor()
if cursor != nil {
return s.moveCursor(cursor)
}
- } else {
- return nil
}
return nil
}
@@ -805,13 +925,14 @@ func (s *splashCmp) logoGap() int {
// Bindings implements SplashPage.
func (s *splashCmp) Bindings() []key.Binding {
- if s.showClaudeAuthMethodChooser {
+ switch {
+ case s.showClaudeAuthMethodChooser:
return []key.Binding{
s.keyMap.Select,
s.keyMap.Tab,
s.keyMap.Back,
}
- } else if s.showClaudeOAuth2 {
+ case s.showClaudeOAuth2:
bindings := []key.Binding{
s.keyMap.Select,
}
@@ -819,18 +940,18 @@ func (s *splashCmp) Bindings() []key.Binding {
bindings = append(bindings, s.keyMap.Copy)
}
return bindings
- } else if s.needsAPIKey {
+ case s.needsAPIKey:
return []key.Binding{
s.keyMap.Select,
s.keyMap.Back,
}
- } else if s.isOnboarding {
+ case s.isOnboarding:
return []key.Binding{
s.keyMap.Select,
s.keyMap.Next,
s.keyMap.Previous,
}
- } else if s.needsProjectInit {
+ case s.needsProjectInit:
return []key.Binding{
s.keyMap.Select,
s.keyMap.Yes,
@@ -838,8 +959,9 @@ func (s *splashCmp) Bindings() []key.Binding {
s.keyMap.Tab,
s.keyMap.LeftRight,
}
+ default:
+ return []key.Binding{}
}
- return []key.Binding{}
}
func (s *splashCmp) getMaxInfoWidth() int {
@@ -940,3 +1062,11 @@ func (s *splashCmp) IsClaudeOAuthURLState() bool {
func (s *splashCmp) IsClaudeOAuthComplete() bool {
return s.showClaudeOAuth2 && s.claudeOAuth2.State == claude.OAuthStateCode && s.claudeOAuth2.ValidationState == claude.OAuthValidationStateValid
}
+
+func (s *splashCmp) IsShowingHyperOAuth2() bool {
+ return s.showHyperDeviceFlow
+}
+
+func (s *splashCmp) IsShowingCopilotOAuth2() bool {
+ return s.showCopilotDeviceFlow
+}
@@ -174,13 +174,10 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
case tea.KeyPressMsg:
switch {
// Handle Hyper device flow keys
- case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && (m.showHyperDeviceFlow || m.showCopilotDeviceFlow):
- if m.hyperDeviceFlow != nil {
- return m, m.hyperDeviceFlow.CopyCode()
- }
- if m.copilotDeviceFlow != nil {
- return m, m.copilotDeviceFlow.CopyCode()
- }
+ case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
+ return m, m.hyperDeviceFlow.CopyCode()
+ case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow:
+ return m, m.copilotDeviceFlow.CopyCode()
case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
return m, tea.Sequence(
tea.SetClipboard(m.claudeOAuth2.URL),
@@ -202,6 +199,9 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
}
selectedItem := m.modelList.SelectedModel()
+ if selectedItem == nil {
+ return m, nil
+ }
modelType := config.SelectedModelTypeLarge
if m.modelList.GetModelType() == SmallModelType {
@@ -310,6 +310,11 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.hyperDeviceFlow.SetWidth(m.width - 2)
return m, m.hyperDeviceFlow.Init()
case catwalk.InferenceProviderCopilot:
+ if token, ok := config.Get().ImportCopilot(); ok {
+ m.selectedModel = selectedItem
+ m.selectedModelType = modelType
+ return m, m.saveOauthTokenAndContinue(token, true)
+ }
m.showCopilotDeviceFlow = true
m.selectedModel = selectedItem
m.selectedModelType = modelType
@@ -337,28 +342,26 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
return m, m.modelList.SetModelType(LargeModelType)
}
case key.Matches(msg, m.keyMap.Close):
- if m.showHyperDeviceFlow {
+ switch {
+ case m.showHyperDeviceFlow:
if m.hyperDeviceFlow != nil {
m.hyperDeviceFlow.Cancel()
}
m.showHyperDeviceFlow = false
m.selectedModel = nil
- }
- if m.showCopilotDeviceFlow {
+ case m.showCopilotDeviceFlow:
if m.copilotDeviceFlow != nil {
m.copilotDeviceFlow.Cancel()
}
m.showCopilotDeviceFlow = false
m.selectedModel = nil
- }
- if m.showClaudeAuthMethodChooser {
+ case m.showClaudeAuthMethodChooser:
m.claudeAuthMethodChooser.SetDefaults()
m.showClaudeAuthMethodChooser = false
m.keyMap.isClaudeAuthChoiceHelp = false
m.keyMap.isClaudeOAuthHelp = false
return m, nil
- }
- if m.needsAPIKey {
+ case m.needsAPIKey:
if m.isAPIKeyValid {
return m, nil
}
@@ -369,37 +372,40 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.apiKeyValue = ""
m.apiKeyInput.Reset()
return m, nil
+ default:
+ return m, util.CmdHandler(dialogs.CloseDialogMsg{})
}
- return m, util.CmdHandler(dialogs.CloseDialogMsg{})
default:
- if m.showClaudeAuthMethodChooser {
+ switch {
+ case m.showClaudeAuthMethodChooser:
u, cmd := m.claudeAuthMethodChooser.Update(msg)
m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
return m, cmd
- } else if m.showClaudeOAuth2 {
+ case m.showClaudeOAuth2:
u, cmd := m.claudeOAuth2.Update(msg)
m.claudeOAuth2 = u.(*claude.OAuth2)
return m, cmd
- } else if m.needsAPIKey {
+ case m.needsAPIKey:
u, cmd := m.apiKeyInput.Update(msg)
m.apiKeyInput = u.(*APIKeyInput)
return m, cmd
- } else {
+ default:
u, cmd := m.modelList.Update(msg)
m.modelList = u
return m, cmd
}
}
case tea.PasteMsg:
- if m.showClaudeOAuth2 {
+ switch {
+ case m.showClaudeOAuth2:
u, cmd := m.claudeOAuth2.Update(msg)
m.claudeOAuth2 = u.(*claude.OAuth2)
return m, cmd
- } else if m.needsAPIKey {
+ case m.needsAPIKey:
u, cmd := m.apiKeyInput.Update(msg)
m.apiKeyInput = u.(*APIKeyInput)
return m, cmd
- } else {
+ default:
var cmd tea.Cmd
m.modelList, cmd = m.modelList.Update(msg)
return m, cmd
@@ -31,7 +31,9 @@ import (
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/filepicker"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/reasoning"
"github.com/charmbracelet/crush/internal/tui/page"
@@ -335,7 +337,14 @@ func (p *chatPage) Update(msg tea.Msg) (util.Model, tea.Cmd) {
cmds = append(cmds, cmd)
return p, tea.Batch(cmds...)
- case claude.ValidationCompletedMsg, claude.AuthenticationCompleteMsg:
+ case claude.ValidationCompletedMsg,
+ claude.AuthenticationCompleteMsg,
+ hyper.DeviceFlowCompletedMsg,
+ hyper.DeviceAuthInitiatedMsg,
+ hyper.DeviceFlowErrorMsg,
+ copilot.DeviceAuthInitiatedMsg,
+ copilot.DeviceFlowErrorMsg,
+ copilot.DeviceFlowCompletedMsg:
if p.focusedPane == PanelTypeSplash {
u, cmd := p.splash.Update(msg)
p.splash = u.(splash.Splash)
@@ -604,8 +613,11 @@ func (p *chatPage) View() string {
pillsArea = pillsRow
}
- style := t.S().Base.MarginTop(1).PaddingLeft(3)
- pillsArea = style.Render(pillsArea)
+ pillsArea = t.S().Base.
+ MaxWidth(p.width).
+ MarginTop(1).
+ PaddingLeft(3).
+ Render(pillsArea)
}
if p.compact {
@@ -1050,7 +1062,8 @@ func (p *chatPage) Help() help.KeyMap {
fullList = append(fullList, []key.Binding{v})
}
case p.isOnboarding && p.splash.IsShowingClaudeOAuth2():
- if p.splash.IsClaudeOAuthURLState() {
+ switch {
+ case p.splash.IsClaudeOAuthURLState():
shortList = append(shortList,
key.NewBinding(
key.WithKeys("enter"),
@@ -1061,14 +1074,25 @@ func (p *chatPage) Help() help.KeyMap {
key.WithHelp("c", "copy url"),
),
)
- } else if p.splash.IsClaudeOAuthComplete() {
+ case p.splash.IsClaudeOAuthComplete():
shortList = append(shortList,
key.NewBinding(
key.WithKeys("enter"),
key.WithHelp("enter", "continue"),
),
)
- } else {
+ case p.splash.IsShowingHyperOAuth2() || p.splash.IsShowingCopilotOAuth2():
+ shortList = append(shortList,
+ key.NewBinding(
+ key.WithKeys("enter"),
+ key.WithHelp("enter", "copy url & open signup"),
+ ),
+ key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "copy url"),
+ ),
+ )
+ default:
shortList = append(shortList,
key.NewBinding(
key.WithKeys("enter"),
@@ -10,7 +10,8 @@ const (
ArrowRightIcon string = "→"
CenterSpinnerIcon string = "⋯"
LoadingIcon string = "⟳"
- DocumentIcon string = "🖼"
+ ImageIcon string = "■"
+ TextIcon string = "☰"
ModelIcon string = "◇"
// Tool call icons
@@ -367,6 +367,17 @@
"type": "array",
"description": "Paths to files containing context information for the AI"
},
+ "skills_paths": {
+ "items": {
+ "type": "string",
+ "examples": [
+ "~/.config/crush/skills",
+ "./skills"
+ ]
+ },
+ "type": "array",
+ "description": "Paths to directories containing Agent Skills (folders with SKILL.md files)"
+ },
"tui": {
"$ref": "#/$defs/TUIOptions",
"description": "Terminal user interface options"