From 2958347730de4324ccc688274fb6a756718da372 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20B=C3=B6hm?= Date: Mon, 18 Mar 2024 17:23:49 +0100 Subject: [PATCH] Fix type issues --- varats/varats/tools/driver_feature.py | 30 +++++++++++++++------------ 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/varats/varats/tools/driver_feature.py b/varats/varats/tools/driver_feature.py index 5191fc77c..fc3403fa2 100644 --- a/varats/varats/tools/driver_feature.py +++ b/varats/varats/tools/driver_feature.py @@ -107,7 +107,7 @@ def __prompt_location( feature_name: str, commit_hash: FullCommitHash, old_location: tp.Optional[Location] = None -) -> tp.Optional[Location]: +) -> Location: parse_location: tp.Callable[[str], tp.Optional[Location]] if old_location is not None: parse_location = partial( @@ -116,9 +116,12 @@ def __prompt_location( else: parse_location = Location.parse_string - return click.prompt( - f"Enter location for feature {feature_name} @ {commit_hash.short_hash}", - value_proc=parse_location + return tp.cast( + Location, + click.prompt( + f"Enter location for feature {feature_name} @ {commit_hash.short_hash}", + value_proc=parse_location + ) ) @@ -126,12 +129,13 @@ def __get_location_content(commit: Commit, location: Location) -> tp.Optional[str]: assert location.start_line == location.end_line, \ "Multiline locations are not supported yet." - lines = tp.cast(Blob, commit.tree[location.file]).data.splitlines() + lines: tp.List[bytes] = tp.cast(Blob, commit.tree[location.file + ]).data.splitlines() if len(lines) < location.start_line: return None - line = lines[location.start_line - 1].decode("utf-8") + line: str = lines[location.start_line - 1].decode("utf-8") if len(line) <= location.end_col: return None @@ -157,7 +161,7 @@ def main() -> None: required=False ) def __annotate( - project: str, revision: tp.Optional[str], outfile: tp.IO + project: str, revision: tp.Optional[str], outfile: tp.TextIO ) -> None: initialize_cli_tool() @@ -187,9 +191,9 @@ def __annotate( last_annotations[feature_name] = FeatureAnnotation( feature_name, location, commit_hash ) - last_annotation_targets[feature_name] = __get_location_content( - first_commit, location - ) + target = __get_location_content(first_commit, location) + assert target is not None, "Target must not be None" + last_annotation_targets[feature_name] = target tracked_features[feature_name] = [] LOG.debug( f"Tracking {feature_name} @ {location}: {last_annotation_targets[feature_name]}" @@ -221,9 +225,9 @@ def __annotate( last_annotations[feature] = FeatureAnnotation( feature, new_location, commit_hash ) - last_annotation_targets[feature] = __get_location_content( - commit, new_location - ) + new_target = __get_location_content(commit, new_location) + assert new_target is not None, "Target must not be None" + last_annotation_targets[feature] = new_target LOG.debug( f"Tracking {feature} @ {new_location}: {last_annotation_targets[feature]}" )