Skip to content

Commit

Permalink
fix coverage gaps
Browse files Browse the repository at this point in the history
  • Loading branch information
sigma67 committed Jul 29, 2024
1 parent 7e9d4f0 commit 03894fe
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/towncrier/_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_version(package_dir: str, package: str) -> str:
Try to extract the version from the distribution version metadata that matches
`package`, then fall back to looking for the package in `package_dir`.
"""
version: str
version: str | None

# First try to get the version from the package metadata.
if version := _get_metadata_version(package):
Expand Down
59 changes: 31 additions & 28 deletions src/towncrier/test/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,37 @@ def test_tuple(self):
version = get_version(temp, "mytestproja")
self.assertEqual(version, "1.3.12")

def test_incremental(self):
"""
An incremental-like Version __version__ is picked up.
"""
temp = self.mktemp()
os.makedirs(temp)
os.makedirs(os.path.join(temp, "mytestprojinc"))

with open(os.path.join(temp, "mytestprojinc", "__init__.py"), "w") as f:
f.write(
"""
class Version:
'''
This is emulating a Version object from incremental.
'''
def __init__(self, *version_parts):
self.version = version_parts
def base(self):
return '.'.join(map(str, self.version))
__version__ = Version(1, 3, 12, "rc1")
"""
)
version = get_version(temp, "mytestprojinc")
self.assertEqual(version, "1.3.12rc1")
def test_incremental(self):
"""
An incremental-like Version __version__ is picked up.
"""
temp = self.mktemp()
os.makedirs(temp)
os.makedirs(os.path.join(temp, "mytestprojinc"))

with open(os.path.join(temp, "mytestprojinc", "__init__.py"), "w") as f:
f.write(
"""
class Version:
'''
This is emulating a Version object from incremental.
'''
def __init__(self, *version_parts):
self.version = version_parts
def base(self):
return '.'.join(map(str, self.version))
__version__ = Version(1, 3, 12, "rc1")
"""
)

version = get_version(temp, "mytestprojinc")
self.assertEqual(version, "1.3.12rc1")

project = get_project_name(temp, "mytestprojinc")
self.assertEqual(project, "Mytestprojinc")

def test_not_incremental(self):
"""
Expand Down

0 comments on commit 03894fe

Please sign in to comment.