@@ -12,9 +12,11 @@ import (
"fmt"
"io"
"net/http"
+ "net/http/cookiejar"
"net/url"
"strconv"
"strings"
+ "sync"
"time"
)
@@ -33,18 +35,18 @@ type Client struct {
baseURL string
httpClient *http.Client
auth Auth
- // cached session from password login
- cachedCookieName string
- cachedCookieVal string
+ authMu sync.Mutex
}
// New creates a new SilverBullet client.
// The baseURL is normalized to remove trailing slashes.
func New(baseURL string, auth Auth) *Client {
+ jar, _ := cookiejar.New(nil)
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
httpClient: &http.Client{
Timeout: 6 * time.Hour, // long enough for lua_script with X-Timeout up to 21600s
+ Jar: jar,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
@@ -92,11 +94,8 @@ func (c *Client) ExecuteLua(ctx context.Context, script string, timeout int) (*L
req.Header.Set("Content-Type", "text/plain")
req.Header.Set("X-Timeout", strconv.Itoa(timeout))
- if err := c.setAuth(req); err != nil {
- return nil, fmt.Errorf("setting auth: %w", err)
- }
- resp, err := c.httpClient.Do(req)
+ resp, err := c.do(req)
if err != nil {
return nil, fmt.Errorf("executing lua: %w", err)
}
@@ -135,11 +134,7 @@ func (c *Client) Screenshot(ctx context.Context) ([]byte, error) {
return nil, fmt.Errorf("creating request: %w", err)
}
- if err := c.setAuth(req); err != nil {
- return nil, fmt.Errorf("setting auth: %w", err)
- }
-
- resp, err := c.httpClient.Do(req)
+ resp, err := c.do(req)
if err != nil {
return nil, fmt.Errorf("fetching screenshot: %w", err)
}
@@ -179,10 +174,6 @@ func (c *Client) ConsoleLogs(ctx context.Context, limit int, since int64) (*Logs
return nil, fmt.Errorf("creating request: %w", err)
}
- if err := c.setAuth(req); err != nil {
- return nil, fmt.Errorf("setting auth: %w", err)
- }
-
q := req.URL.Query()
q.Set("limit", strconv.Itoa(limit))
if since > 0 {
@@ -190,7 +181,7 @@ func (c *Client) ConsoleLogs(ctx context.Context, limit int, since int64) (*Logs
}
req.URL.RawQuery = q.Encode()
- resp, err := c.httpClient.Do(req)
+ resp, err := c.do(req)
if err != nil {
return nil, fmt.Errorf("fetching logs: %w", err)
}
@@ -218,9 +209,62 @@ func (c *Client) resolveURL(path string) (string, error) {
return url.JoinPath(c.baseURL, path)
}
+// do sends an authenticated request.
+// If a password session is rejected, it discards the cached cookie, logs in again,
+// and retries the request once.
+func (c *Client) do(req *http.Request) (*http.Response, error) {
+ if err := c.setAuth(req); err != nil {
+ return nil, fmt.Errorf("setting auth: %w", err)
+ }
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ if !c.usesPasswordAuth() || !isAuthFailure(resp.StatusCode) {
+ return resp, nil
+ }
+
+ _, _ = io.Copy(io.Discard, resp.Body)
+ _ = resp.Body.Close()
+
+ if req.Body != nil && req.GetBody == nil {
+ return nil, fmt.Errorf("authentication failed and request body cannot be replayed")
+ }
+ if err := c.expireSessionCookies(req.URL); err != nil {
+ return nil, fmt.Errorf("expiring session cookie: %w", err)
+ }
+
+ retryReq := req.Clone(req.Context())
+ retryReq.Header = req.Header.Clone()
+ retryReq.Header.Del("Cookie")
+ if req.GetBody != nil {
+ retryBody, err := req.GetBody()
+ if err != nil {
+ return nil, fmt.Errorf("recreating request body: %w", err)
+ }
+ retryReq.Body = retryBody
+ }
+
+ if err := c.setAuth(retryReq); err != nil {
+ return nil, fmt.Errorf("refreshing auth: %w", err)
+ }
+
+ return c.httpClient.Do(retryReq)
+}
+
+func (c *Client) usesPasswordAuth() bool {
+ return c.auth.User != "" && c.auth.Pass != ""
+}
+
+func isAuthFailure(status int) bool {
+ return status == http.StatusUnauthorized || (status >= 300 && status < 400)
+}
+
// setAuth sets authentication headers on the request.
// Bearer token goes on the Authorization header.
-// Password auth logs in via POST /.auth and sends the session cookie.
+// Password auth logs in via POST /.auth; the client's cookie jar sends the
+// resulting session cookie on subsequent requests until it expires.
// Both can coexist — they use different headers.
func (c *Client) setAuth(req *http.Request) error {
// Bearer token on Authorization header
@@ -229,7 +273,7 @@ func (c *Client) setAuth(req *http.Request) error {
}
// Password auth via session cookie
- if c.auth.User != "" && c.auth.Pass != "" {
+ if c.usesPasswordAuth() {
if err := c.ensureSessionCookie(req); err != nil {
return fmt.Errorf("session login: %w", err)
}
@@ -238,27 +282,29 @@ func (c *Client) setAuth(req *http.Request) error {
return nil
}
-// ensureSessionCookie logs in via POST /.auth if needed and sets the session cookie.
+// ensureSessionCookie logs in via POST /.auth if the cookie jar has no current
+// auth cookie. The jar applies normal cookie expiration rules.
func (c *Client) ensureSessionCookie(req *http.Request) error {
- // Use cached session if available
- if c.cachedCookieName != "" && c.cachedCookieVal != "" {
- req.AddCookie(&http.Cookie{
- Name: c.cachedCookieName,
- Value: c.cachedCookieVal,
- })
+ loginURL, err := c.resolveURL("/.auth")
+ if err != nil {
+ return fmt.Errorf("resolving login URL: %w", err)
+ }
+ loginCookieURL, err := url.Parse(loginURL)
+ if err != nil {
+ return fmt.Errorf("parsing login URL: %w", err)
+ }
+
+ c.authMu.Lock()
+ defer c.authMu.Unlock()
+
+ if c.hasAuthCookie(loginCookieURL) {
return nil
}
- // Log in to get session cookie
form := url.Values{}
form.Set("username", c.auth.User)
form.Set("password", c.auth.Pass)
- loginURL, err := c.resolveURL("/.auth")
- if err != nil {
- return fmt.Errorf("resolving login URL: %w", err)
- }
-
loginReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, loginURL, strings.NewReader(form.Encode()))
if err != nil {
return fmt.Errorf("creating login request: %w", err)
@@ -274,19 +320,43 @@ func (c *Client) ensureSessionCookie(req *http.Request) error {
if loginResp.StatusCode != http.StatusOK {
return fmt.Errorf("login failed with status %d", loginResp.StatusCode)
}
+ if !c.hasAuthCookie(loginCookieURL) {
+ return fmt.Errorf("login succeeded but no auth cookie returned")
+ }
- // Extract session cookie from Set-Cookie header
- for _, cookie := range loginResp.Cookies() {
- if strings.HasPrefix(cookie.Name, "auth_") {
- c.cachedCookieName = cookie.Name
- c.cachedCookieVal = cookie.Value
- req.AddCookie(&http.Cookie{
- Name: cookie.Name,
- Value: cookie.Value,
- })
- return nil
+ return nil
+}
+
+func (c *Client) hasAuthCookie(u *url.URL) bool {
+ if c.httpClient.Jar == nil {
+ return false
+ }
+ for _, cookie := range c.httpClient.Jar.Cookies(u) {
+ if strings.HasPrefix(cookie.Name, "auth_") && cookie.Value != "" {
+ return true
}
}
+ return false
+}
+
+func (c *Client) expireSessionCookies(u *url.URL) error {
+ if c.httpClient.Jar == nil {
+ return nil
+ }
+
+ c.authMu.Lock()
+ defer c.authMu.Unlock()
- return fmt.Errorf("login succeeded but no auth cookie returned")
+ cookies := c.httpClient.Jar.Cookies(u)
+ for _, cookie := range cookies {
+ if strings.HasPrefix(cookie.Name, "auth_") {
+ c.httpClient.Jar.SetCookies(u, []*http.Cookie{{
+ Name: cookie.Name,
+ Value: "",
+ Path: "/",
+ MaxAge: -1,
+ }})
+ }
+ }
+ return nil
}
@@ -10,7 +10,9 @@ import (
"fmt"
"net/http"
"net/http/httptest"
+ "sync"
"testing"
+ "time"
)
func newTestServer(mux *http.ServeMux) *httptest.Server {
@@ -312,6 +314,131 @@ func TestBothAuth(t *testing.T) {
}
}
+func TestPasswordAuthReloginsAfterCookieExpires(t *testing.T) {
+ mux := http.NewServeMux()
+ var mu sync.Mutex
+ loginCount := 0
+ currentToken := ""
+
+ mux.HandleFunc("/.auth", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ return
+ }
+ if r.FormValue("username") != "testuser" || r.FormValue("password") != "testpass" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ mu.Lock()
+ loginCount++
+ currentToken = fmt.Sprintf("mock-jwt-token-%d", loginCount)
+ token := currentToken
+ mu.Unlock()
+
+ http.SetCookie(w, &http.Cookie{
+ Name: "auth_session",
+ Value: token,
+ MaxAge: 1,
+ })
+ w.WriteHeader(http.StatusOK)
+ })
+ mux.HandleFunc("/.runtime/logs", func(w http.ResponseWriter, r *http.Request) {
+ cookie, err := r.Cookie("auth_session")
+ if err != nil {
+ t.Error("expected auth_session cookie, got none")
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ mu.Lock()
+ expectedToken := currentToken
+ mu.Unlock()
+ if cookie.Value != expectedToken {
+ t.Errorf("expected cookie value %q, got %q", expectedToken, cookie.Value)
+ }
+ _ = json.NewEncoder(w).Encode(LogsResult{Logs: []LogEntry{}})
+ })
+
+ srv := newTestServer(mux)
+ defer srv.Close()
+
+ client := testClientWithBasicAuth(srv.URL)
+ if _, err := client.ConsoleLogs(context.Background(), 1, 0); err != nil {
+ t.Fatalf("first ConsoleLogs failed: %v", err)
+ }
+ time.Sleep(1100 * time.Millisecond)
+ if _, err := client.ConsoleLogs(context.Background(), 1, 0); err != nil {
+ t.Fatalf("second ConsoleLogs failed: %v", err)
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+ if loginCount != 2 {
+ t.Fatalf("expected 2 logins after cookie expiry, got %d", loginCount)
+ }
+}
+
+func TestPasswordAuthRefreshesRejectedSession(t *testing.T) {
+ mux := http.NewServeMux()
+ var mu sync.Mutex
+ loginCount := 0
+
+ mux.HandleFunc("/.auth", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ return
+ }
+ if r.FormValue("username") != "testuser" || r.FormValue("password") != "testpass" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ mu.Lock()
+ loginCount++
+ token := fmt.Sprintf("mock-jwt-token-%d", loginCount)
+ mu.Unlock()
+
+ http.SetCookie(w, &http.Cookie{
+ Name: "auth_session",
+ Value: token,
+ })
+ w.WriteHeader(http.StatusOK)
+ })
+ mux.HandleFunc("/.runtime/lua_script", func(w http.ResponseWriter, r *http.Request) {
+ cookie, err := r.Cookie("auth_session")
+ if err != nil {
+ t.Error("expected auth_session cookie, got none")
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+ if cookie.Value == "mock-jwt-token-1" {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+ if cookie.Value != "mock-jwt-token-2" {
+ t.Errorf("expected refreshed cookie value, got %q", cookie.Value)
+ }
+
+ result := LuaResult{Result: json.RawMessage(`"ok"`)}
+ _ = json.NewEncoder(w).Encode(result)
+ })
+
+ srv := newTestServer(mux)
+ defer srv.Close()
+
+ client := testClientWithBasicAuth(srv.URL)
+ if _, err := client.ExecuteLua(context.Background(), "return 'ok'", 30); err != nil {
+ t.Fatalf("ExecuteLua failed: %v", err)
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+ if loginCount != 2 {
+ t.Fatalf("expected 2 logins after rejected session, got %d", loginCount)
+ }
+}
+
func TestTrailingSlashURL(t *testing.T) {
// Verify that trailing slashes in the base URL are handled correctly
client := New("http://localhost:3000/", Auth{})