Source code for project_release.git

"""Git related code."""
import enum
import logging
from typing import List
from typing import Optional

import git
from pydantic import Field
from pydantic import field_validator
from pydantic import validate_call

from ._pydantic import Listable
from ._pydantic import UseDefaultValueModel

logger = logging.getLogger(__name__)


class GitBanchConfig(UseDefaultValueModel):
    """Git branch configuration.

    Holds the branch names used for development and for releases. Both
    fields are normalized to a list before validation: ``None`` becomes
    an empty list and a single value is wrapped in a one-element list.
    """

    # Names of the development branches (empty by default).
    development: List[str] = []  # noqa: RUF012
    # Names of the release branches (empty by default).
    release: List[str] = []  # noqa: RUF012

    @field_validator("development", "release", mode="before")
    @classmethod
    @validate_call
    def _validate_branches(cls, value: Optional[Listable[str]]) -> List[str]:
        """Coerce the raw field value into a list of branch names."""
        if value is None:
            return []
        # A list is passed through untouched; anything else is treated as
        # a single entry.
        return value if isinstance(value, list) else [value]
class GitCommitConfig(UseDefaultValueModel):
    """Git commit configuration."""

    # Commit message template; presumably ``%(version)s`` is substituted
    # by the caller with the version being released -- TODO confirm.
    message: str = "bump: version %(version)s"
    # Exposed under the ``sign-off`` alias; disabled by default.
    sign_off: bool = Field(
        default=False,
        alias="sign-off",
    )
    # Exposed under the ``gpg-sign`` alias; disabled by default.
    gpg_sign: bool = Field(
        default=False,
        alias="gpg-sign",
    )
class GitTagConfig(UseDefaultValueModel):
    """Git tag configuration."""

    # Tag name template, exposed under the ``format`` alias.
    # NOTE(review): the attribute itself is spelled ``fromat``; it is part
    # of the public interface, so it is kept as is -- confirm with callers
    # whether the spelling is intentional.
    fromat: str = Field(
        default="%(version)s",
        alias="format",
    )
    # Tag message template.
    message: str = "version %(version)s"
    # Enabled by default; presumably selects annotated tags -- TODO confirm.
    annotate: bool = True
    # Exposed under the ``gpg-sign`` alias; disabled by default.
    gpg_sign: bool = Field(
        default=False,
        alias="gpg-sign",
    )
class GitConfig(UseDefaultValueModel):
    """Git configuration.

    Groups the branch, commit and tag settings; every section defaults
    to an instance of its model built without arguments.
    """

    # Branch related settings.
    branch: GitBanchConfig = GitBanchConfig()
    # Commit related settings.
    commit: GitCommitConfig = GitCommitConfig()
    # Tag related settings.
    tag: GitTagConfig = GitTagConfig()
def current_repo() -> git.Repo:
    """Return the current git repository.

    The repository is looked up with ``search_parent_directories=True``.

    Returns
    -------
    git.repo.base.Repo
        The current git repository.

    Raises
    ------
    SystemExit
        If the git repository is invalid.
    """
    try:
        return git.Repo(search_parent_directories=True)
    except git.InvalidGitRepositoryError as error:
        # Turn the library error into a plain exit message while keeping
        # the original exception as the cause.
        raise SystemExit("Not in a git repository") from error
def local_branch_names(repo: git.Repo) -> List[str]:
    """Return the local branches.

    Parameters
    ----------
    repo
        The current git repository.

    Returns
    -------
    list of str
        The name of every head of the repository.
    """
    heads = repo.heads
    return [head.name for head in heads]
def remote_branch_names(remote: Optional[git.Remote]) -> List[str]:
    """Return the remote branches.

    Parameters
    ----------
    remote
        The selected remote, or None when no remote is selected.

    Returns
    -------
    list of str
        The remote branch names with the leading ``<remote>/`` prefix
        removed. The remote ``HEAD`` ref is skipped. An empty list is
        returned when ``remote`` is None.
    """
    if remote is None:
        return []

    prefix = f"{remote}/"
    names = []
    for ref in remote.refs:
        name = ref.name
        # Strip the remote prefix only at the start of the name: a plain
        # ``str.replace`` would also mangle branches that contain
        # "<remote>/" further down, e.g. "feature/origin/fix".
        if name.startswith(prefix):
            name = name[len(prefix):]
        # Skip the symbolic remote HEAD only; branches whose name merely
        # ends with "HEAD" are regular branches and must be kept.
        if name == "HEAD":
            continue
        names.append(name)
    return names
def repo_branch_names(repo: git.Repo, remote: Optional[git.Remote]) -> List[str]:
    """Return the local and remote branches.

    Parameters
    ----------
    repo
        The current git repository.
    remote
        The selected remote, or None when no remote is selected.

    Returns
    -------
    list of str
        The local and remote branch names without duplicates: local
        branches first, followed by the remote branches that have no
        local counterpart.
    """
    names = [*local_branch_names(repo), *remote_branch_names(remote)]
    # dict.fromkeys removes duplicates while keeping first-seen order, so
    # the result is stable between runs (a set would order the names
    # according to the randomized string hashes).
    return list(dict.fromkeys(names))
def current_branch_name(repo: git.Repo) -> Optional[str]:
    """Return the current branch name.

    Parameters
    ----------
    repo
        The current git repository.

    Returns
    -------
    str or None
        The current branch name or None if detached.
    """
    try:
        branch_name = repo.active_branch.name
    except TypeError:
        # A TypeError while reading the active branch is reported as
        # "no current branch" (detached HEAD).
        branch_name = None
    return branch_name
class RefPosition(enum.Enum):
    """Relative position of a ref with respect to a base ref."""

    AHEAD = enum.auto()
    """The ref is ahead of the base ref."""

    EQUAL = enum.auto()
    """The ref and the base ref are the same."""

    BEHIND = enum.auto()
    """The ref is behind the base ref."""

    UNRELATED = enum.auto()
    """The ref and the base ref do not share the same history."""
def compare_ref(
    ref: git.SymbolicReference, base_ref: git.SymbolicReference
) -> RefPosition:
    """Compare a ref to a base ref and return the relative position.

    Parameters
    ----------
    ref
        The ref to be compared.
    base_ref
        The base ref of the comparison.

    Returns
    -------
    RefPosition
        The relative position of the ref compared to the base ref.
        ``UNRELATED`` is returned when each side has commits that the
        other one is missing.
    """
    repo = ref.repo

    def _has_commits(rev_range: str) -> bool:
        """Tell whether the revision range contains at least one commit."""
        # Only emptiness matters, so the listing is limited to a single
        # commit instead of walking (and counting) the whole range.
        return bool(list(repo.iter_commits(rev_range, max_count=1)))

    ahead = _has_commits(f"{base_ref}..{ref}")
    behind = _has_commits(f"{ref}..{base_ref}")

    if ahead and behind:
        return RefPosition.UNRELATED
    if ahead:
        return RefPosition.AHEAD
    if behind:
        return RefPosition.BEHIND
    return RefPosition.EQUAL
def update_branch(branch: git.Head, remote: git.Remote, force_update: bool) -> None:
    """Update a local git branch.

    The branch is first configured to track its counterpart on the
    remote, then moved to the tracking branch commit when it is behind.
    A branch that is ahead or up to date is left untouched.

    Parameters
    ----------
    branch
        The selected branch.
    remote
        The selected remote.
    force_update
        Allow forced branch update.

    Raises
    ------
    SystemExit
        If the tracking branch is invalid or if the branch cannot be updated.
    """
    if branch.name not in remote_branch_names(remote):
        # Nothing to synchronize with.
        logger.debug("The %s branch does not yet exist on the remote", branch)
        return

    expected_tracking_branch = remote.refs[branch.name]
    tracking_branch = branch.tracking_branch()
    if tracking_branch is None:
        logger.debug(
            "Configuring the %s branch to track %s", branch, expected_tracking_branch
        )
        branch.set_tracking_branch(expected_tracking_branch)
    elif tracking_branch != expected_tracking_branch:
        # Tracking another ref is only fixed when the update is forced.
        if not force_update:
            raise SystemExit(
                f"Invalid tracking branch: {branch} is tracking "
                f"{tracking_branch} instead of {expected_tracking_branch}"
            )
        logger.warning(
            "Configuring the %s branch to track %s", branch, expected_tracking_branch
        )
        branch.set_tracking_branch(expected_tracking_branch)
    # From here on the branch tracks the expected remote ref.
    tracking_branch = expected_tracking_branch

    position = compare_ref(branch, tracking_branch)
    if position == RefPosition.EQUAL:
        logger.debug("The %s branch is up to date with %s", branch, tracking_branch)
    elif position == RefPosition.AHEAD:
        logger.debug("The %s branch is ahead of %s", branch, tracking_branch)
    elif position == RefPosition.BEHIND:
        logger.info("Updating the %s branch to match %s", branch, tracking_branch)
        branch.reference = tracking_branch.commit
    elif position == RefPosition.UNRELATED:
        if not force_update:
            # The trailing space keeps the two literals from being glued
            # together in the final message.
            raise SystemExit(
                f"Cannot update branch: {branch} and {tracking_branch} "
                "does not share the same history."
            )
        logger.warning("Forcing the %s branch to match %s", branch, tracking_branch)
        branch.reference = tracking_branch.commit
def create_branch(
    branch_name: str, default_ref: git.SymbolicReference, remote: Optional[git.Remote]
) -> git.Head:
    """Create a local branch.

    When the branch already exists on the remote, the local branch is
    created from the remote one and configured to track it. Otherwise it
    is created from the default reference.

    Parameters
    ----------
    branch_name
        The selected branch name.
    default_ref
        Default reference of the new branch.
    remote
        The selected remote, or None when no remote is selected.

    Returns
    -------
    git.refs.head.Head
        The newly created branch.
    """
    if not remote or branch_name not in remote_branch_names(remote):
        # No remote counterpart: start from the default reference.
        return default_ref.repo.create_head(branch_name, default_ref)

    remote_ref = remote.refs[branch_name]
    created = remote_ref.repo.create_head(branch_name, remote_ref)
    created.set_tracking_branch(remote_ref)
    return created