diff --git a/beast2xml/beast2.py b/beast2xml/beast2.py index 3b815bb..b2f2a0a 100644 --- a/beast2xml/beast2.py +++ b/beast2xml/beast2.py @@ -59,38 +59,39 @@ class BEAST2XML(object): date time unit. """ + TRACELOG_SUFFIX = ".log" TREELOG_SUFFIX = ".trees" _rate_change_to_param_dict = { - 'birthRateChangeTimes': 'reproductiveNumber', - 'deathRateChangeTimes': 'becomeUninfectiousRate', - 'samplingRateChangeTimes': 'samplingProportion' + "birthRateChangeTimes": "reproductiveNumber", + "deathRateChangeTimes": "becomeUninfectiousRate", + "samplingRateChangeTimes": "samplingProportion", } _distribution_args = { - 'Uniform': ['lower', 'upper', 'offset'], - 'LogNormal': ['meanInRealSpace', 'M', 'S', 'offset'], - 'Beta': ['alpha', 'beta', 'offset'], - 'Gamma': ['alpha', 'beta', 'offset'], - 'InverseGamma': ['alpha', 'beta', 'offset'], - 'LaplaceDistribution': ['mu', 'scale', 'offset'], - 'Exponential': ['mean', 'offset'], - 'Normal': ['mean', 'sigma', 'offset'], - 'WeibullDistribution': ['shape', 'scale', 'meanOne', 'offset'], - 'Poisson': ["lambda", 'offset'] + "Uniform": ["lower", "upper", "offset"], + "LogNormal": ["meanInRealSpace", "M", "S", "offset"], + "Beta": ["alpha", "beta", "offset"], + "Gamma": ["alpha", "beta", "offset"], + "InverseGamma": ["alpha", "beta", "offset"], + "LaplaceDistribution": ["mu", "scale", "offset"], + "Exponential": ["mean", "offset"], + "Normal": ["mean", "sigma", "offset"], + "WeibullDistribution": ["shape", "scale", "meanOne", "offset"], + "Poisson": ["lambda", "offset"], } def __init__( - self, - template=None, - clock_model="strict", - sequence_id_date_regex=None, - sequence_id_age_regex=None, - sequence_id_regex_must_match=True, - date_unit="year", + self, + template=None, + clock_model="strict", + sequence_id_date_regex=None, + sequence_id_age_regex=None, + sequence_id_regex_must_match=True, + date_unit="year", ): if template is None: self._tree = ET.parse( - files('beast2xml').joinpath(f'templates/{clock_model}.xml') + files("beast2xml").joinpath(f"templates/{clock_model}.xml") ) else: self._tree = ET.parse(template) @@ -130,24 +131,26 @@ def find_elements(tree): """ result = {} root = tree.getroot() - for tag in ("data", - "run", - "./run/state/tree/trait", - "./run/logger[@id='tracelog']", - "./run/logger[@id='treelog.t:", - "./run/logger[@id='screenlog']"): + for tag in ( + "data", + "run", + "./run/state/tree/trait", + "./run/logger[@id='tracelog']", + "./run/logger[@id='treelog.t:", + "./run/logger[@id='screenlog']", + ): if tag == "./run/logger[@id='treelog.t:": tag = tag + data_id + "']" element = root.find(tag) if element is None: - raise ValueError('Could not find %r tag in XML template' % tag) - if tag == 'data': - data_id = element.get('id') + raise ValueError("Could not find %r tag in XML template" % tag) + if tag == "data": + data_id = element.get("id") result[tag] = element return result - def add_ages(self, age_data, seperator='\t', age_column='year_decimal'): + def add_ages(self, age_data, seperator="\t", age_column="year_decimal"): """ Add age data. @@ -164,17 +167,19 @@ def add_ages(self, age_data, seperator='\t', age_column='year_decimal'): if isinstance(age_data, str): age_data = pd.read_csv(age_data, sep=seperator) if isinstance(age_data, pd.DataFrame): - if 'id' in age_data.columns: - age_data = age_data.set_index('id') - elif 'strain' in age_data.columns: - age_data = age_data.set_index('strain') + if "id" in age_data.columns: + age_data = age_data.set_index("id") + elif "strain" in age_data.columns: + age_data = age_data.set_index("strain") else: raise ValueError("An age_data column must be id or strain") age_data = age_data[age_column] if isinstance(age_data, pd.Series): age_data = age_data.to_dict() if not isinstance(age_data, dict): - raise ValueError('age_data must be a C{dict} a C{pd.DataFrame}, a C{pd.Series} or a path to tsv/csv.') + raise ValueError( + "age_data must be a C{dict} a C{pd.DataFrame}, a C{pd.Series} or a path to tsv/csv." + ) self._age_by_full_id.update(age_data) age_data = {key.split()[0]: value for key, value in age_data.items()} self._age_by_short_id.update(age_data) @@ -257,7 +262,7 @@ def add_sequence(self, sequence, age=None): if age is None: if self._sequence_id_regex_must_match and ( - self._sequence_id_date_regex or self._sequence_id_age_regex + self._sequence_id_date_regex or self._sequence_id_age_regex ): raise ValueError( "No sequence date or age could be found in %r " @@ -279,11 +284,19 @@ def add_sequences(self, sequences): for sequence in sequences: self.add_sequence(sequence) - def _to_xml_tree(self, chain_length=None, default_age=0.0, - date_direction=None, log_file_basename=None, - trace_log_every=None, tree_log_every=None, screen_log_every=None, - store_state_every=None, - transform_func=None, mimic_beauti=False): + def _to_xml_tree( + self, + chain_length=None, + default_age=0.0, + date_direction=None, + log_file_basename=None, + trace_log_every=None, + tree_log_every=None, + screen_log_every=None, + store_state_every=None, + transform_func=None, + mimic_beauti=False, + ): """ Generate xml.etree.ElementTree for running on BEAST. @@ -336,42 +349,61 @@ def _to_xml_tree(self, chain_length=None, default_age=0.0, elements = self.find_elements(self._tree) # Get data element_path - data = elements['data'] - data_id = data.get('id') + data = elements["data"] + data_id = data.get("id") tree_logger_key = "./run/logger[@id='treelog.t:" + data_id + "']" # Delete any existing children of the data node. delete_child_nodes(data) - trait = elements['./run/state/tree/trait'] + trait = elements["./run/state/tree/trait"] if not isinstance(default_age, (float, int)): - raise TypeError('The default age must be an integer or float.') + raise TypeError("The default age must be an integer or float.") age_by_short_id = deepcopy(self._age_by_short_id) # Add in all sequences. - for sequence in sorted(self._sequences): # Sorting adds the sequences alphabetically like in BEAUti. + for sequence in sorted( + self._sequences + ): # Sorting adds the sequences alphabetically like in BEAUti. seq_id = sequence.id short_id = seq_id.split()[0] if seq_id not in age_by_short_id: age_by_short_id[short_id] = default_age - ET.SubElement(data, 'sequence', id='seq_' + short_id, spec="Sequence", taxon=short_id, - totalcount='4', value=sequence.sequence) + ET.SubElement( + data, + "sequence", + id="seq_" + short_id, + spec="Sequence", + taxon=short_id, + totalcount="4", + value=sequence.sequence, + ) - trait_order = [sequence.id.split()[0] for sequence in self._sequences] # ensures order is the same as BEAUti's. - trait_text = [short_id + '=' + str(age_by_short_id[short_id]) for short_id in trait_order] + trait_order = [ + sequence.id.split()[0] for sequence in self._sequences + ] # ensures order is the same as BEAUti's. + trait_text = [ + short_id + "=" + str(age_by_short_id[short_id]) for short_id in trait_order + ] if date_direction is None: - trait.set('value', ','.join(trait_text)) # Replaces old age info with new age info + trait.set( + "value", ",".join(trait_text) + ) # Replaces old age info with new age info if trait.get("traitname") is None: - raise ValueError('No traitname attribute in dateTrait element_path of template xml.' + - ' This can be set through date_direction argument with the options ' + - '"backward", "forward" or "date".') + raise ValueError( + "No traitname attribute in dateTrait element_path of template xml." + + " This can be set through date_direction argument with the options " + + '"backward", "forward" or "date".' + ) else: - if date_direction not in ['backward', 'forward', 'date']: - raise ValueError('If supplied date_direction must be either "backward", "forward" or "date".') - trait.set('value', '') # Removes old age info - trait.text = ',\n'.join(trait_text) + '\n' # Adds new age info - if date_direction == 'date': + if date_direction not in ["backward", "forward", "date"]: + raise ValueError( + 'If supplied date_direction must be either "backward", "forward" or "date".' + ) + trait.set("value", "") # Removes old age info + trait.text = ",\n".join(trait_text) + "\n" # Adds new age info + if date_direction == "date": trait.set("traitname", date_direction) else: trait.set("traitname", "date-" + date_direction) @@ -392,7 +424,7 @@ def _to_xml_tree(self, chain_length=None, default_age=0.0, logger.set("fileName", log_file_basename + self.TRACELOG_SUFFIX) # Tree log. logger = elements[tree_logger_key] - logger.set('fileName', log_file_basename + self.TREELOG_SUFFIX) + logger.set("fileName", log_file_basename + self.TREELOG_SUFFIX) if trace_log_every is not None: logger = elements["./run/logger[@id='tracelog']"] @@ -400,28 +432,30 @@ def _to_xml_tree(self, chain_length=None, default_age=0.0, if tree_log_every is not None: logger = elements[tree_logger_key] - logger.set('logEvery', str(tree_log_every)) + logger.set("logEvery", str(tree_log_every)) if screen_log_every is not None: logger = elements["./run/logger[@id='screenlog']"] logger.set("logEvery", str(screen_log_every)) tree = self._tree if transform_func is None else transform_func(self._tree) - ET.indent(tree, '\t') + ET.indent(tree, "\t") return tree - def to_string(self, - chain_length=None, - default_age=0.0, - date_direction=None, - log_file_basename=None, - trace_log_every=None, - tree_log_every=None, - screen_log_every=None, - store_state_every=None, - transform_func=None, - mimic_beauti=False): - """ Generate str version of xml.etree.ElementTree for running on BEAST. + def to_string( + self, + chain_length=None, + default_age=0.0, + date_direction=None, + log_file_basename=None, + trace_log_every=None, + tree_log_every=None, + screen_log_every=None, + store_state_every=None, + transform_func=None, + mimic_beauti=False, + ): + """Generate str version of xml.etree.ElementTree for running on BEAST. Parameters ---------- @@ -464,28 +498,37 @@ def to_string(self, tree: str String representation of xml.etree.ElementTree for running on BEAST """ - tree = self._to_xml_tree(chain_length=chain_length, default_age=default_age, - date_direction=date_direction, log_file_basename=log_file_basename, - trace_log_every=trace_log_every, tree_log_every=tree_log_every, - screen_log_every=screen_log_every, store_state_every=store_state_every, - transform_func=transform_func, mimic_beauti=mimic_beauti) + tree = self._to_xml_tree( + chain_length=chain_length, + default_age=default_age, + date_direction=date_direction, + log_file_basename=log_file_basename, + trace_log_every=trace_log_every, + tree_log_every=tree_log_every, + screen_log_every=screen_log_every, + store_state_every=store_state_every, + transform_func=transform_func, + mimic_beauti=mimic_beauti, + ) stream = six.StringIO() tree.write(stream, "unicode" if six.PY3 else "utf-8", xml_declaration=True) return stream.getvalue() - def to_xml(self, - path, - chain_length=None, - default_age=0.0, - date_direction=None, - log_file_basename=None, - trace_log_every=None, - tree_log_every=None, - screen_log_every=None, - store_state_every=None, - transform_func=None, - mimic_beauti=False): + def to_xml( + self, + path, + chain_length=None, + default_age=0.0, + date_direction=None, + log_file_basename=None, + trace_log_every=None, + tree_log_every=None, + screen_log_every=None, + store_state_every=None, + transform_func=None, + mimic_beauti=False, + ): """ Generate xml.etree.ElementTree for running on BEAST and write to xml file. @@ -533,34 +576,54 @@ def to_xml(self, """ if not isinstance(path, str): - raise TypeError('filename must be a string.') - tree = self._to_xml_tree(chain_length=chain_length, default_age=default_age, - date_direction=date_direction, log_file_basename=log_file_basename, - trace_log_every=trace_log_every, tree_log_every=tree_log_every, - screen_log_every=screen_log_every, store_state_every=store_state_every, - transform_func=transform_func, mimic_beauti=mimic_beauti) - tree.write(path, 'unicode' if six.PY3 else 'utf-8', xml_declaration=True) - - def _search_for_parameter_in_element(self, element_path, parameter, wild_card_ending): + raise TypeError("filename must be a string.") + tree = self._to_xml_tree( + chain_length=chain_length, + default_age=default_age, + date_direction=date_direction, + log_file_basename=log_file_basename, + trace_log_every=trace_log_every, + tree_log_every=tree_log_every, + screen_log_every=screen_log_every, + store_state_every=store_state_every, + transform_func=transform_func, + mimic_beauti=mimic_beauti, + ) + tree.write(path, "unicode" if six.PY3 else "utf-8", xml_declaration=True) + + def _search_for_parameter_in_element( + self, element_path, parameter, wild_card_ending + ): if wild_card_ending: - parameter_nodes = [potential_parameter_node - for potential_parameter_node in self._tree.findall(element_path) - if potential_parameter_node.attrib['id'].startswith(parameter)] + parameter_nodes = [ + potential_parameter_node + for potential_parameter_node in self._tree.findall(element_path) + if potential_parameter_node.attrib["id"].startswith(parameter) + ] else: - parameter_nodes = self._tree.findall("./run/state/parameter[@id='%s']" % parameter) + parameter_nodes = self._tree.findall( + "./run/state/parameter[@id='%s']" % parameter + ) if len(parameter_nodes) == 0: - raise ValueError('No parameter with id %s (or starting with) was found.' % parameter) + raise ValueError( + "No parameter with id %s (or starting with) was found." % parameter + ) if len(parameter_nodes) > 1: - raise ValueError('More than one parameter with id %s (or starting with) was found.' % parameter) + raise ValueError( + "More than one parameter with id %s (or starting with) was found." + % parameter + ) return parameter_nodes[0] - def change_parameter_state_node(self, - parameter, - value=None, - dimension=None, - lower=None, - upper=None, - wild_card_ending=True): + def change_parameter_state_node( + self, + parameter, + value=None, + dimension=None, + lower=None, + upper=None, + wild_card_ending=True, + ): """ Change the values of the stateNode for a parameter. @@ -581,19 +644,23 @@ def change_parameter_state_node(self, """ if all(arg is None for arg in [value, dimension, lower, upper]): - raise ValueError('Either a value, dimension, lower or upper argument must be provided.') + raise ValueError( + "Either a value, dimension, lower or upper argument must be provided." + ) - parameter_node = self._search_for_parameter_in_element("./run/state/parameter", parameter, wild_card_ending) + parameter_node = self._search_for_parameter_in_element( + "./run/state/parameter", parameter, wild_card_ending + ) if value is not None: parameter_node.text = str(value) if dimension is not None: if not isinstance(dimension, int): - raise ValueError('Dimension must be an integer.') - parameter_node.set('dimension', str(dimension)) + raise ValueError("Dimension must be an integer.") + parameter_node.set("dimension", str(dimension)) if lower is not None: - parameter_node.set('lower', str(lower)) + parameter_node.set("lower", str(lower)) if upper is not None: - parameter_node.set('upper', str(upper)) + parameter_node.set("upper", str(upper)) def change_prior(self, parameter, distribution, wild_card_ending=True, **kwargs): """ @@ -612,81 +679,99 @@ def change_prior(self, parameter, distribution, wild_card_ending=True, **kwargs) """ parameter_node = self._search_for_parameter_in_element( - "./run/distribution/distribution/prior", - parameter, - wild_card_ending + "./run/distribution/distribution/prior", parameter, wild_card_ending ) - if distribution in ['lognorm', 'lognormal', - 'log norm', 'log normal', - 'log-norm', 'log-normal']: - distribution = 'LogNormal' - elif distribution in ['inversegamma', - 'inverse gamma', - 'inverse-gamma']: - distribution = 'InverseGamma' - elif distribution in ['LogNormal', 'InverseGamma', 'WeibullDistribution', 'LaplaceDistribution']: + if distribution in [ + "lognorm", + "lognormal", + "log norm", + "log normal", + "log-norm", + "log-normal", + ]: + distribution = "LogNormal" + elif distribution in ["inversegamma", "inverse gamma", "inverse-gamma"]: + distribution = "InverseGamma" + elif distribution in [ + "LogNormal", + "InverseGamma", + "WeibullDistribution", + "LaplaceDistribution", + ]: pass else: distribution = distribution.title() - if distribution in ['Weibull', 'Laplace']: - distribution = distribution + 'Distribution' + if distribution in ["Weibull", "Laplace"]: + distribution = distribution + "Distribution" if distribution not in self._distribution_args: raise ValueError( - 'Currently only the following distributions are supported: ' + - ', '.join(self._distribution_args.keys()) + '.' + "Currently only the following distributions are supported: " + + ", ".join(self._distribution_args.keys()) + + "." ) - if distribution == 'LogNormal': - if 'meanInRealSpace' not in kwargs: - kwargs['meanInRealSpace'] = 'false' - elif isinstance(kwargs['meanInRealSpace'], bool): - kwargs['meanInRealSpace'] = str(kwargs['meanInRealSpace']).lower() + if distribution == "LogNormal": + if "meanInRealSpace" not in kwargs: + kwargs["meanInRealSpace"] = "false" + elif isinstance(kwargs["meanInRealSpace"], bool): + kwargs["meanInRealSpace"] = str(kwargs["meanInRealSpace"]).lower() else: - raise TypeError('Argument meanInRealSpace must be a boolean or' + - ' not given as argument.') - - if distribution == 'WeibullDistribution': - if 'meanOne' not in kwargs: - kwargs['meanOne'] = 'false' - elif isinstance(kwargs['meanOne'], bool): - kwargs['meanOne'] = str(kwargs['meanOne']).lower() + raise TypeError( + "Argument meanInRealSpace must be a boolean or" + + " not given as argument." + ) + + if distribution == "WeibullDistribution": + if "meanOne" not in kwargs: + kwargs["meanOne"] = "false" + elif isinstance(kwargs["meanOne"], bool): + kwargs["meanOne"] = str(kwargs["meanOne"]).lower() else: - raise TypeError('Argument meanOne must be a boolean or' + - ' not given as argument.') + raise TypeError( + "Argument meanOne must be a boolean or" + " not given as argument." + ) - if 'offset' not in kwargs: - kwargs['offset'] = 0.0 + if "offset" not in kwargs: + kwargs["offset"] = 0.0 for key in kwargs: if key not in self._distribution_args[distribution]: raise ValueError( - key + - ' is not a parameter of the ' + - distribution + ' distribution.') + key + + " is not a parameter of the " + + distribution + + " distribution." + ) for arg in self._distribution_args[distribution]: if arg not in kwargs.keys(): - raise ValueError('%s has not being given as a kwarg.' % arg) + raise ValueError("%s has not being given as a kwarg." % arg) for keyword, value in kwargs.items(): - if keyword not in ['meanInRealSpace', 'meanOne']: - if isinstance(value, (int, float)): - raise TypeError('Argument %s must be an integer or float.' % keyword) + if keyword not in ["meanInRealSpace", "meanOne"]: + if not isinstance(value, (int, float)): + raise TypeError( + "Argument %s must be an integer or float." % keyword + ) kwargs = {key: str(value) for key, value in kwargs.items()} delete_child_nodes(parameter_node) - i_d = '_'.join([parameter, distribution]) - if distribution == 'Uniform': - self.change_parameter_state_node(parameter, **kwargs) - - if distribution in ['Poisson', 'WeibullDistribution']: - ET.SubElement(parameter_node, - 'distr', - id=i_d, - spec="beast.math.distributions."+distribution, - **kwargs) + i_d = "_".join([parameter, distribution]) + if distribution == "Uniform": + self.change_parameter_state_node( + parameter, lower=kwargs["lower"], upper=kwargs["upper"] + ) + + if distribution in ["Poisson", "WeibullDistribution"]: + ET.SubElement( + parameter_node, + "distr", + id=i_d, + spec="beast.math.distributions." + distribution, + **kwargs, + ) else: ET.SubElement(parameter_node, distribution, id=i_d, name="distr", **kwargs) @@ -702,7 +787,9 @@ def add_rate_change_dates(self, parameter, dates): """ if not isinstance(dates, (list, tuple, pd.Series, pd.DatetimeIndex)): - raise TypeError('dates must be a list, tuple pandas.Series or pandas.DatetimeIndex.') + raise TypeError( + "dates must be a list, tuple pandas.Series or pandas.DatetimeIndex." + ) year_decimals = [date_to_decimal(item) for item in dates] youngest_tip = max(self._age_by_short_id.values()) times = [youngest_tip - year_decimal for year_decimal in year_decimals] @@ -721,19 +808,20 @@ def add_rate_change_times(self, parameter, times): """ skyline_element = self._tree.find( - "./run/distribution/distribution/distribution[@spec='beast.evolution.speciation.BirthDeathSkylineModel']") + "./run/distribution/distribution/distribution[@spec='beast.evolution.speciation.BirthDeathSkylineModel']" + ) if skyline_element is None: raise ValueError( - 'No distribution of spec BirthDeathSkylineModel was found.' + - 'Currently this method only supports Birth Death Skyline Models.' + "No distribution of spec BirthDeathSkylineModel was found." + + "Currently this method only supports Birth Death Skyline Models." ) - rev_time_element = skyline_element.find('reverseTimeArrays') + rev_time_element = skyline_element.find("reverseTimeArrays") if rev_time_element is None: rev_time_array = [False, False, False, False, False] else: rev_time_array = rev_time_element.text - rev_time_array = rev_time_array.split(' ') - rev_time_array = [val in ['true', 'True', 'TRUE'] for val in rev_time_array] + rev_time_array = rev_time_array.split(" ") + rev_time_array = [val in ["true", "True", "TRUE"] for val in rev_time_array] del rev_time_element # Delete existing rev_time_element. if parameter == "birthRateChangeTimes": rev_time_array[0] = True @@ -743,27 +831,31 @@ def add_rate_change_times(self, parameter, times): rev_time_array[2] = True else: raise ValueError( - 'Currently this method only supports parameter being: ' + - 'birthRateChangeTimes (for changes in reproductive number), ' - 'deathRateChangeTimes (for changes in uninfectious rate) and ' + - 'samplingRateChangeTimes (for sampling proportion).' + "Currently this method only supports parameter being: " + + "birthRateChangeTimes (for changes in reproductive number), " + "deathRateChangeTimes (for changes in uninfectious rate) and " + + "samplingRateChangeTimes (for sampling proportion)." ) rev_time_array = [str(val).lower() for val in rev_time_array] - rev_time_array = ' '.join(rev_time_array) - ET.SubElement(skyline_element, - 'reverseTimeArrays', - spec="beast.core.parameter.BooleanParameter", - value=rev_time_array) + rev_time_array = " ".join(rev_time_array) + ET.SubElement( + skyline_element, + "reverseTimeArrays", + spec="beast.core.parameter.BooleanParameter", + value=rev_time_array, + ) parameter_element = skyline_element.find(parameter) if parameter_element is not None: del parameter_element # delete old parameter element_path if it exists. if not any(time == 0.0 for time in times): times.append(0.0) dimensions = len(times) - ET.SubElement(skyline_element, - parameter, - spec="parameter.RealParameter", - value=' '.join([str(time) for time in times])) + ET.SubElement( + skyline_element, + parameter, + spec="parameter.RealParameter", + value=" ".join([str(time) for time in times]), + ) self.change_parameter_state_node( - self._rate_change_to_param_dict[parameter], - dimension=dimensions) + self._rate_change_to_param_dict[parameter], dimension=dimensions + ) diff --git a/test/test_beast2.py b/test/test_beast2.py index eb80569..13326b8 100644 --- a/test/test_beast2.py +++ b/test/test_beast2.py @@ -6,7 +6,6 @@ from beast2xml import BEAST2XML from datetime import date, timedelta - try: from unittest.mock import mock_open, patch except ImportError: @@ -230,32 +229,38 @@ def test_one_sequence(self): # The sequence id with the default age of 0.0 must be in the traits. trait = elements["./run/state/tree/trait"] - self.assertEqual(trait.attrib['value'], "id1=0.0") + self.assertEqual(trait.attrib["value"], "id1=0.0") def test_sequence_id_age_regex(self): """ Using a sequence id age regex must result in the expected XML. """ - xml = BEAST2XML(clock_model=self.clock_model, sequence_id_age_regex="^.*_([0-9]+)") + xml = BEAST2XML( + clock_model=self.clock_model, sequence_id_age_regex="^.*_([0-9]+)" + ) xml.add_sequence(Read("id1_80_xxx", "ACTG")) tree = ET.ElementTree(ET.fromstring(xml.to_string())) elements = BEAST2XML.find_elements(tree) # The sequence id with the default age of 0.0 must be in the traits. trait = elements["./run/state/tree/trait"] - self.assertEqual(trait.attrib['value'], "id1_80_xxx=80.0") + self.assertEqual(trait.attrib["value"], "id1_80_xxx=80.0") def test_sequence_id_age_regex_non_matching(self): """ Using a sequence id age regex with a sequence id that does not match must result in a ValueError. """ - xml = BEAST2XML(clock_model=self.clock_model, sequence_id_age_regex="^.*_([0-9]+)") + xml = BEAST2XML( + clock_model=self.clock_model, sequence_id_age_regex="^.*_([0-9]+)" + ) error = ( r"^No sequence date or age could be found in 'id1' using the " r"sequence id date/age regular expressions\.$" ) - assertRaisesRegex(self, ValueError, error, xml.add_sequence, Read("id1", "ACTG")) + assertRaisesRegex( + self, ValueError, error, xml.add_sequence, Read("id1", "ACTG") + ) def test_sequence_id_regex_non_matching_not_an_error(self): """ @@ -274,7 +279,7 @@ def test_sequence_id_regex_non_matching_not_an_error(self): # The sequence id with the passed default age must be in the traits. trait = elements["./run/state/tree/trait"] - self.assertEqual(trait.attrib['value'], "id1_xxx=50") + self.assertEqual(trait.attrib["value"], "id1_xxx=50") def test_one_sequence_with_date_regex_and_date_unit_in_years(self): """ @@ -293,7 +298,7 @@ def test_one_sequence_with_date_regex_and_date_unit_in_years(self): # The sequence id with an age of ~2 years must be in the traits. trait = elements["./run/state/tree/trait"] # Note that the following is not exact! - trait_value = float(trait.attrib['value'].split("=")[1]) + trait_value = float(trait.attrib["value"].split("=")[1]) self.assertAlmostEqual(trait_value, 1.97, places=1) self.assertIs(None, trait.get("units")) @@ -318,7 +323,7 @@ def test_one_sequence_with_date_regex_and_date_unit_in_months(self): # The sequence id with an age of ~2 months must be in the traits. trait = elements["./run/state/tree/trait"] # Note that the following is not exact! - trait_value = float(trait.attrib['value'].split("=")[1]) + trait_value = float(trait.attrib["value"].split("=")[1]) self.assertAlmostEqual(trait_value, 1.9712, places=2) self.assertEqual("month", trait.get("units")) @@ -340,7 +345,7 @@ def test_one_sequence_with_date_regex_and_date_unit_in_days(self): # The sequence id with an age of 10 days must be in the traits. trait = elements["./run/state/tree/trait"] - self.assertEqual(trait.attrib['value'], id_+"=10.0") + self.assertEqual(trait.attrib["value"], id_ + "=10.0") self.assertEqual("day", trait.get("units")) def test_one_sequence_with_age(self): @@ -355,7 +360,7 @@ def test_one_sequence_with_age(self): # The sequence id with the given age must be in the traits. trait = elements["./run/state/tree/trait"] - self.assertEqual(trait.attrib['value'],"id1=44") + self.assertEqual(trait.attrib["value"], "id1=44") def test_one_sequence_with_age_added_together(self): """ @@ -369,7 +374,7 @@ def test_one_sequence_with_age_added_together(self): # The sequence id with the given age must be in the traits. trait = elements["./run/state/tree/trait"] - self.assertEqual(trait.attrib['value'], "id1=44") + self.assertEqual(trait.attrib["value"], "id1=44") def test_add_sequences(self): """ @@ -411,7 +416,7 @@ def test_add_sequences(self): # The sequence ids with the default age of 0.0 must be in the traits. trait = elements["./run/state/tree/trait"] - self.assertEqual(trait.attrib['value'], 'id1=0.0,id2=0.0,id3=0.0') + self.assertEqual(trait.attrib["value"], "id1=0.0,id2=0.0,id3=0.0") def test_chain_length(self): """ @@ -433,7 +438,7 @@ def test_default_age(self): # The sequence id with the default age of 33.0 must be in the traits. trait = elements["./run/state/tree/trait"] - self.assertEqual(trait.attrib['value'], 'id1=33.0') + self.assertEqual(trait.attrib["value"], "id1=33.0") def test_log_file_base_name(self): """