graphql: replace GitRef.isDefault by Repository.head (#1551)

Michael MurΓ© and Copilot created

isDefault was barely working

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

Change summary

api/graphql/graph/git.generated.go        |  51 ------------
api/graphql/graph/repository.generated.go | 101 +++++++++++++++++++++++++
api/graphql/graph/root.generated.go       |   2 
api/graphql/graph/root_.generated.go      |  22 ++--
api/graphql/models/gen_models.go          |   2 
api/graphql/resolvers/repo.go             |  12 ++
api/graphql/schema/git.graphql            |   2 
api/graphql/schema/repository.graphql     |   5 +
repository/browse.go                      |   5 
repository/gogit.go                       |  62 ++++++--------
repository/mock_repo.go                   |  35 +++++---
repository/repo.go                        |   6 +
repository/repo_testing.go                |  13 ++
13 files changed, 197 insertions(+), 121 deletions(-)

Detailed changes

api/graphql/graph/git.generated.go πŸ”—

@@ -2319,50 +2319,6 @@ func (ec *executionContext) fieldContext_GitRef_hash(_ context.Context, field gr
 	return fc, nil
 }
 
-func (ec *executionContext) _GitRef_isDefault(ctx context.Context, field graphql.CollectedField, obj *models.GitRef) (ret graphql.Marshaler) {
-	fc, err := ec.fieldContext_GitRef_isDefault(ctx, field)
-	if err != nil {
-		return graphql.Null
-	}
-	ctx = graphql.WithFieldContext(ctx, fc)
-	defer func() {
-		if r := recover(); r != nil {
-			ec.Error(ctx, ec.Recover(ctx, r))
-			ret = graphql.Null
-		}
-	}()
-	resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
-		ctx = rctx // use context from middleware stack in children
-		return obj.IsDefault, nil
-	})
-	if err != nil {
-		ec.Error(ctx, err)
-		return graphql.Null
-	}
-	if resTmp == nil {
-		if !graphql.HasFieldError(ctx, fc) {
-			ec.Errorf(ctx, "must not be null")
-		}
-		return graphql.Null
-	}
-	res := resTmp.(bool)
-	fc.Result = res
-	return ec.marshalNBoolean2bool(ctx, field.Selections, res)
-}
-
-func (ec *executionContext) fieldContext_GitRef_isDefault(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
-	fc = &graphql.FieldContext{
-		Object:     "GitRef",
-		Field:      field,
-		IsMethod:   false,
-		IsResolver: false,
-		Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
-			return nil, errors.New("field of type Boolean does not have child fields")
-		},
-	}
-	return fc, nil
-}
-
 func (ec *executionContext) _GitRefConnection_nodes(ctx context.Context, field graphql.CollectedField, obj *models.GitRefConnection) (ret graphql.Marshaler) {
 	fc, err := ec.fieldContext_GitRefConnection_nodes(ctx, field)
 	if err != nil {
@@ -2410,8 +2366,6 @@ func (ec *executionContext) fieldContext_GitRefConnection_nodes(_ context.Contex
 				return ec.fieldContext_GitRef_type(ctx, field)
 			case "hash":
 				return ec.fieldContext_GitRef_hash(ctx, field)
-			case "isDefault":
-				return ec.fieldContext_GitRef_isDefault(ctx, field)
 			}
 			return nil, fmt.Errorf("no field named %q was found under type GitRef", field.Name)
 		},
@@ -3414,11 +3368,6 @@ func (ec *executionContext) _GitRef(ctx context.Context, sel ast.SelectionSet, o
 			if out.Values[i] == graphql.Null {
 				out.Invalids++
 			}
-		case "isDefault":
-			out.Values[i] = ec._GitRef_isDefault(ctx, field, obj)
-			if out.Values[i] == graphql.Null {
-				out.Invalids++
-			}
 		default:
 			panic("unknown field " + strconv.Quote(field.Name))
 		}

api/graphql/graph/repository.generated.go πŸ”—

@@ -31,6 +31,7 @@ type RepositoryResolver interface {
 	Commits(ctx context.Context, obj *models.Repository, after *string, first *int, ref string, path *string, since *time.Time, until *time.Time) (*models.GitCommitConnection, error)
 	Commit(ctx context.Context, obj *models.Repository, hash string) (*models.GitCommitMeta, error)
 	LastCommits(ctx context.Context, obj *models.Repository, ref string, path *string, names []string) ([]*models.GitLastCommit, error)
+	Head(ctx context.Context, obj *models.Repository) (*models.GitCommitMeta, error)
 	ValidLabels(ctx context.Context, obj *models.Repository, after *string, before *string, first *int, last *int) (*models.LabelConnection, error)
 }
 
@@ -1655,6 +1656,69 @@ func (ec *executionContext) fieldContext_Repository_lastCommits(ctx context.Cont
 	return fc, nil
 }
 
+func (ec *executionContext) _Repository_head(ctx context.Context, field graphql.CollectedField, obj *models.Repository) (ret graphql.Marshaler) {
+	fc, err := ec.fieldContext_Repository_head(ctx, field)
+	if err != nil {
+		return graphql.Null
+	}
+	ctx = graphql.WithFieldContext(ctx, fc)
+	defer func() {
+		if r := recover(); r != nil {
+			ec.Error(ctx, ec.Recover(ctx, r))
+			ret = graphql.Null
+		}
+	}()
+	resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
+		ctx = rctx // use context from middleware stack in children
+		return ec.resolvers.Repository().Head(rctx, obj)
+	})
+	if err != nil {
+		ec.Error(ctx, err)
+		return graphql.Null
+	}
+	if resTmp == nil {
+		return graphql.Null
+	}
+	res := resTmp.(*models.GitCommitMeta)
+	fc.Result = res
+	return ec.marshalOGitCommit2αš–githubαš—comαš‹gitαš‘bugαš‹gitαš‘bugαš‹apiαš‹graphqlαš‹modelsᚐGitCommitMeta(ctx, field.Selections, res)
+}
+
+func (ec *executionContext) fieldContext_Repository_head(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
+	fc = &graphql.FieldContext{
+		Object:     "Repository",
+		Field:      field,
+		IsMethod:   true,
+		IsResolver: true,
+		Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
+			switch field.Name {
+			case "hash":
+				return ec.fieldContext_GitCommit_hash(ctx, field)
+			case "shortHash":
+				return ec.fieldContext_GitCommit_shortHash(ctx, field)
+			case "message":
+				return ec.fieldContext_GitCommit_message(ctx, field)
+			case "fullMessage":
+				return ec.fieldContext_GitCommit_fullMessage(ctx, field)
+			case "authorName":
+				return ec.fieldContext_GitCommit_authorName(ctx, field)
+			case "authorEmail":
+				return ec.fieldContext_GitCommit_authorEmail(ctx, field)
+			case "date":
+				return ec.fieldContext_GitCommit_date(ctx, field)
+			case "parents":
+				return ec.fieldContext_GitCommit_parents(ctx, field)
+			case "files":
+				return ec.fieldContext_GitCommit_files(ctx, field)
+			case "diff":
+				return ec.fieldContext_GitCommit_diff(ctx, field)
+			}
+			return nil, fmt.Errorf("no field named %q was found under type GitCommit", field.Name)
+		},
+	}
+	return fc, nil
+}
+
 func (ec *executionContext) _Repository_validLabels(ctx context.Context, field graphql.CollectedField, obj *models.Repository) (ret graphql.Marshaler) {
 	fc, err := ec.fieldContext_Repository_validLabels(ctx, field)
 	if err != nil {
@@ -1833,6 +1897,8 @@ func (ec *executionContext) fieldContext_RepositoryConnection_nodes(_ context.Co
 				return ec.fieldContext_Repository_commit(ctx, field)
 			case "lastCommits":
 				return ec.fieldContext_Repository_lastCommits(ctx, field)
+			case "head":
+				return ec.fieldContext_Repository_head(ctx, field)
 			case "validLabels":
 				return ec.fieldContext_Repository_validLabels(ctx, field)
 			}
@@ -2047,6 +2113,8 @@ func (ec *executionContext) fieldContext_RepositoryEdge_node(_ context.Context,
 				return ec.fieldContext_Repository_commit(ctx, field)
 			case "lastCommits":
 				return ec.fieldContext_Repository_lastCommits(ctx, field)
+			case "head":
+				return ec.fieldContext_Repository_head(ctx, field)
 			case "validLabels":
 				return ec.fieldContext_Repository_validLabels(ctx, field)
 			}
@@ -2492,6 +2560,39 @@ func (ec *executionContext) _Repository(ctx context.Context, sel ast.SelectionSe
 				continue
 			}
 
+			out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
+		case "head":
+			field := field
+
+			innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) {
+				defer func() {
+					if r := recover(); r != nil {
+						ec.Error(ctx, ec.Recover(ctx, r))
+					}
+				}()
+				res = ec._Repository_head(ctx, field, obj)
+				return res
+			}
+
+			if field.Deferrable != nil {
+				dfs, ok := deferred[field.Deferrable.Label]
+				di := 0
+				if ok {
+					dfs.AddField(field)
+					di = len(dfs.Values) - 1
+				} else {
+					dfs = graphql.NewFieldSet([]graphql.CollectedField{field})
+					deferred[field.Deferrable.Label] = dfs
+				}
+				dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler {
+					return innerFunc(ctx, dfs)
+				})
+
+				// don't run the out.Concurrently() call below
+				out.Values[i] = graphql.Null
+				continue
+			}
+
 			out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
 		case "validLabels":
 			field := field

api/graphql/graph/root.generated.go πŸ”—

@@ -1134,6 +1134,8 @@ func (ec *executionContext) fieldContext_Query_repository(ctx context.Context, f
 				return ec.fieldContext_Repository_commit(ctx, field)
 			case "lastCommits":
 				return ec.fieldContext_Repository_lastCommits(ctx, field)
+			case "head":
+				return ec.fieldContext_Repository_head(ctx, field)
 			case "validLabels":
 				return ec.fieldContext_Repository_validLabels(ctx, field)
 			}

api/graphql/graph/root_.generated.go πŸ”—

@@ -367,7 +367,6 @@ type ComplexityRoot struct {
 
 	GitRef struct {
 		Hash      func(childComplexity int) int
-		IsDefault func(childComplexity int) int
 		Name      func(childComplexity int) int
 		ShortName func(childComplexity int) int
 		Type      func(childComplexity int) int
@@ -480,6 +479,7 @@ type ComplexityRoot struct {
 		Bug           func(childComplexity int, prefix string) int
 		Commit        func(childComplexity int, hash string) int
 		Commits       func(childComplexity int, after *string, first *int, ref string, path *string, since *time.Time, until *time.Time) int
+		Head          func(childComplexity int) int
 		Identity      func(childComplexity int, prefix string) int
 		LastCommits   func(childComplexity int, ref string, path *string, names []string) int
 		Name          func(childComplexity int) int
@@ -1827,13 +1827,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
 
 		return e.complexity.GitRef.Hash(childComplexity), true
 
-	case "GitRef.isDefault":
-		if e.complexity.GitRef.IsDefault == nil {
-			break
-		}
-
-		return e.complexity.GitRef.IsDefault(childComplexity), true
-
 	case "GitRef.name":
 		if e.complexity.GitRef.Name == nil {
 			break
@@ -2367,6 +2360,13 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
 
 		return e.complexity.Repository.Commits(childComplexity, args["after"].(*string), args["first"].(*int), args["ref"].(string), args["path"].(*string), args["since"].(*time.Time), args["until"].(*time.Time)), true
 
+	case "Repository.head":
+		if e.complexity.Repository.Head == nil {
+			break
+		}
+
+		return e.complexity.Repository.Head(childComplexity), true
+
 	case "Repository.identity":
 		if e.complexity.Repository.Identity == nil {
 			break
@@ -3202,8 +3202,6 @@ type GitRef {
     type: GitRefType!
     """Commit hash the reference points to."""
     hash: String!
-    """True for the branch HEAD currently points to."""
-    isDefault: Boolean!
 }
 
 """An entry in a git tree (directory listing)."""
@@ -3591,6 +3589,10 @@ type OperationEdge {
     tree listing without blocking the initial tree fetch."""
     lastCommits(ref: String!, path: String, names: [String!]!): [GitLastCommit!]!
 
+    """The currently checked-out commit (branch, tag, hash ...) in the git repository.
+    Null if there is none (bare repo)."""
+    head: GitCommit
+
     """List of valid labels."""
     validLabels(
         """Returns the elements in the list that come after the specified cursor."""

api/graphql/models/gen_models.go πŸ”—

@@ -320,8 +320,6 @@ type GitRef struct {
 	Type GitRefType `json:"type"`
 	// Commit hash the reference points to.
 	Hash string `json:"hash"`
-	// True for the branch HEAD currently points to.
-	IsDefault bool `json:"isDefault"`
 }
 
 type GitRefConnection struct {

api/graphql/resolvers/repo.go πŸ”—

@@ -225,7 +225,6 @@ func (repoResolver) Refs(_ context.Context, obj *models.Repository, after *strin
 				ShortName: b.Name,
 				Type:      models.GitRefTypeBranch,
 				Hash:      string(b.Hash),
-				IsDefault: b.IsDefault,
 			})
 		}
 	}
@@ -425,3 +424,14 @@ func (repoResolver) LastCommits(_ context.Context, obj *models.Repository, ref s
 	}
 	return result, nil
 }
+
+func (repoResolver) Head(_ context.Context, obj *models.Repository) (*models.GitCommitMeta, error) {
+	meta, err := obj.Repo.BrowseRepo().Head()
+	if errors.Is(err, repository.ErrNotFound) {
+		return nil, nil
+	}
+	if err != nil {
+		return nil, err
+	}
+	return &models.GitCommitMeta{Repo: obj.Repo, CommitMeta: meta}, nil
+}

api/graphql/schema/git.graphql πŸ”—

@@ -8,8 +8,6 @@ type GitRef {
     type: GitRefType!
     """Commit hash the reference points to."""
     hash: String!
-    """True for the branch HEAD currently points to."""
-    isDefault: Boolean!
 }
 
 """An entry in a git tree (directory listing)."""

api/graphql/schema/repository.graphql πŸ”—

@@ -84,6 +84,11 @@ type Repository {
     tree listing without blocking the initial tree fetch."""
     lastCommits(ref: String!, path: String, names: [String!]!): [GitLastCommit!]!
 
+    """The commit pointed to by HEAD in the git repository.
+    Null if HEAD cannot be resolved to a commit, for example in an empty or unborn
+    repository, or if HEAD is missing or invalid."""
+    head: GitCommit
+
     """List of valid labels."""
     validLabels(
         """Returns the elements in the list that come after the specified cursor."""

repository/browse.go πŸ”—

@@ -146,9 +146,8 @@ type FileDiff struct {
 
 // BranchInfo describes a local branch returned by RepoBrowse.Branches.
 type BranchInfo struct {
-	Name      string
-	Hash      Hash // commit hash
-	IsDefault bool // true for the branch HEAD points to
+	Name string
+	Hash Hash // commit hash
 }
 
 // TagInfo describes a tag returned by RepoBrowse.Tags.

repository/gogit.go πŸ”—

@@ -1001,39 +1001,8 @@ func (repo *GoGitRepo) resolveRefToHash(ref string) (plumbing.Hash, error) {
 	return plumbing.ZeroHash, ErrNotFound
 }
 
-// defaultBranchName returns the short name of the default branch.
-func (repo *GoGitRepo) defaultBranchName() string {
-	repo.rMutex.Lock()
-	defer repo.rMutex.Unlock()
-
-	// refs/remotes/origin/HEAD is a symbolic ref set by git clone that points
-	// to the remote's default branch (e.g. refs/remotes/origin/main). It is
-	// the most reliable signal for "what does the upstream consider default".
-	ref, err := repo.r.Reference("refs/remotes/origin/HEAD", false)
-	if err == nil && ref.Type() == plumbing.SymbolicReference {
-		const prefix = "refs/remotes/origin/"
-		if target := ref.Target().String(); strings.HasPrefix(target, prefix) {
-			return strings.TrimPrefix(target, prefix)
-		}
-	}
-	// Fall back to well-known names for repos without a configured remote.
-	for _, name := range []string{"main", "master", "trunk", "develop"} {
-		_, err := repo.r.Reference(plumbing.NewBranchReferenceName(name), false)
-		if err == nil {
-			return name
-		}
-	}
-	return ""
-}
-
-// Branches returns all local branches. IsDefault marks the upstream's default
-// branch, determined in order:
-//  1. refs/remotes/origin/HEAD (set by git clone, reflects the server default)
-//  2. First match among: main, master, trunk, develop
-//  3. No branch marked if none of the above resolve
+// Branches returns all local branches (refs/heads/*).
 func (repo *GoGitRepo) Branches() ([]BranchInfo, error) {
-	defaultBranch := repo.defaultBranchName()
-
 	repo.rMutex.Lock()
 	defer repo.rMutex.Unlock()
 
@@ -1048,9 +1017,8 @@ func (repo *GoGitRepo) Branches() ([]BranchInfo, error) {
 			return nil
 		}
 		branches = append(branches, BranchInfo{
-			Name:      r.Name().Short(),
-			Hash:      Hash(r.Hash().String()),
-			IsDefault: r.Name().Short() == defaultBranch,
+			Name: r.Name().Short(),
+			Hash: Hash(r.Hash().String()),
 		})
 		return nil
 	})
@@ -1584,6 +1552,30 @@ func (repo *GoGitRepo) CommitFileDiff(hash Hash, filePath string) (FileDiff, err
 	return FileDiff{}, ErrNotFound
 }
 
+// Head returns the commit that HEAD currently points to.
+func (repo *GoGitRepo) Head() (CommitMeta, error) {
+	repo.rMutex.Lock()
+	defer repo.rMutex.Unlock()
+
+	ref, err := repo.r.Head()
+	if err == plumbing.ErrReferenceNotFound {
+		return CommitMeta{}, ErrNotFound
+	}
+	if err != nil {
+		return CommitMeta{}, err
+	}
+
+	c, err := repo.r.CommitObject(ref.Hash())
+	if err == plumbing.ErrObjectNotFound {
+		return CommitMeta{}, ErrNotFound
+	}
+	if err != nil {
+		return CommitMeta{}, err
+	}
+
+	return commitToMeta(c), nil
+}
+
 // buildDiffHunks converts a go-git FilePatch into DiffHunks with line numbers
 // and context grouping.
 func buildDiffHunks(fp fdiff.FilePatch) []DiffHunk {

repository/mock_repo.go πŸ”—

@@ -233,20 +233,18 @@ type commit struct {
 }
 
 type mockRepoDataBrowse struct {
-	blobs         map[Hash][]byte
-	trees         map[Hash]string
-	commits       map[Hash]commit
-	refs          map[string]Hash
-	defaultBranch string
+	blobs   map[Hash][]byte
+	trees   map[Hash]string
+	commits map[Hash]commit
+	refs    map[string]Hash
 }
 
 func newMockRepoDataBrowse() *mockRepoDataBrowse {
 	return &mockRepoDataBrowse{
-		blobs:         make(map[Hash][]byte),
-		trees:         make(map[Hash]string),
-		commits:       make(map[Hash]commit),
-		refs:          make(map[string]Hash),
-		defaultBranch: "main",
+		blobs:   make(map[Hash][]byte),
+		trees:   make(map[Hash]string),
+		commits: make(map[Hash]commit),
+		refs:    make(map[string]Hash),
 	}
 }
 
@@ -546,9 +544,8 @@ func (r *mockRepoDataBrowse) Branches() ([]BranchInfo, error) {
 			continue
 		}
 		branches = append(branches, BranchInfo{
-			Name:      name,
-			Hash:      hash,
-			IsDefault: name == r.defaultBranch,
+			Name: name,
+			Hash: hash,
 		})
 	}
 	return branches, nil
@@ -798,6 +795,18 @@ func (r *mockRepoDataBrowse) CommitFileDiff(hash Hash, filePath string) (FileDif
 	return fd, nil
 }
 
+func (r *mockRepoDataBrowse) Head() (CommitMeta, error) {
+	hash, ok := r.refs["HEAD"]
+	if !ok {
+		return CommitMeta{}, ErrNotFound
+	}
+	c, ok := r.commits[hash]
+	if !ok {
+		return CommitMeta{}, ErrNotFound
+	}
+	return mockCommitMeta(hash, c), nil
+}
+
 // mockDiffHunks produces a single DiffHunk using a prefix/suffix scan.
 func mockDiffHunks(old, new []byte) []DiffHunk {
 	oldLines := splitBlobLines(old)

repository/repo.go πŸ”—

@@ -220,7 +220,6 @@ type RepoClock interface {
 // refs/heads/<ref>, refs/tags/<ref>, full ref name, raw commit hash.
 type RepoBrowse interface {
 	// Branches returns all local branches (refs/heads/*).
-	// IsDefault marks the branch HEAD points to.
 	// All other ref namespaces β€” including git-bug's internal refs
 	// (refs/bugs/, refs/identities/, …) β€” are excluded.
 	Branches() ([]BranchInfo, error)
@@ -263,6 +262,11 @@ type RepoBrowse interface {
 	// identified by its hash. Diffs against the first parent only; the
 	// initial commit is diffed against the empty tree.
 	CommitFileDiff(hash Hash, filePath string) (FileDiff, error)
+
+	// Head returns the commit that HEAD currently points to.
+	// Returns ErrNotFound if HEAD cannot be resolved to a commit, including
+	// for an empty (unborn) repository.
+	Head() (CommitMeta, error)
 }
 
 // ClockLoader hold which logical clock need to exist for an entity and

repository/repo_testing.go πŸ”—

@@ -466,10 +466,7 @@ func RepoBrowseTest(t *testing.T, repo browsable) {
 		}
 
 		require.Equal(t, c3, byName["main"].Hash)
-		require.True(t, byName["main"].IsDefault)
-
 		require.Equal(t, c2, byName["feature"].Hash)
-		require.False(t, byName["feature"].IsDefault)
 	})
 
 	// ── Tags ──────────────────────────────────────────────────────────────────
@@ -771,4 +768,14 @@ func RepoBrowseTest(t *testing.T, repo browsable) {
 		_, err = repo.CommitFileDiff(randomHash(), "main.go")
 		require.ErrorIs(t, err, ErrNotFound)
 	})
+
+	// ── Head ──────────────────────────────────────────────────────────────────
+
+	t.Run("Head", func(t *testing.T) {
+		require.NoError(t, repo.UpdateRef("HEAD", c3))
+
+		meta, err := repo.Head()
+		require.NoError(t, err)
+		require.Equal(t, c3, meta.Hash)
+	})
 }