diff --git a/deepmd/dpmodel/model/linear_atomic_model.py b/deepmd/dpmodel/model/linear_atomic_model.py new file mode 100644 index 0000000000..dc7e9996c8 --- /dev/null +++ b/deepmd/dpmodel/model/linear_atomic_model.py @@ -0,0 +1,300 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import sys +from abc import ( + abstractmethod, +) +from typing import ( + Dict, + List, + Optional, + Tuple, + Union, +) + +import numpy as np + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.dpmodel.utils.nlist import ( + build_multiple_neighbor_list, + get_multiple_nlist_key, + nlist_distinguish_types, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) +from .dp_atomic_model import ( + DPAtomicModel, +) +from .pairtab_atomic_model import ( + PairTabModel, +) + + +class LinearAtomicModel(BaseAtomicModel): + """Linear model make linear combinations of several existing models. + + Parameters + ---------- + models : list[DPAtomicModel or PairTabModel] + A list of models to be combined. PairTabModel must be used together with a DPAtomicModel. + """ + + def __init__( + self, + models: List[BaseAtomicModel], + **kwargs, + ): + super().__init__() + self.models = models + self.distinguish_type_list = [ + model.distinguish_types() for model in self.models + ] + + def distinguish_types(self) -> bool: + """If distinguish different types by sorting.""" + return False + + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return max(self.get_model_rcuts()) + + def get_model_rcuts(self) -> List[float]: + """Get the cut-off radius for each individual models.""" + return [model.get_rcut() for model in self.models] + + def get_sel(self) -> List[int]: + return [max([model.get_nsel() for model in self.models])] + + def get_model_nsels(self) -> List[int]: + """Get the processed sels for each individual models. Not distinguishing types.""" + return [model.get_nsel() for model in self.models] + + def get_model_sels(self) -> List[Union[int, List[int]]]: + """Get the sels for each individual models.""" + return [model.get_sel() for model in self.models] + + def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]: + # sort the pair of rcut and sels in ascending order, first based on sel, then on rcut. + zipped = sorted( + zip(self.get_model_rcuts(), self.get_model_nsels()), + key=lambda x: (x[1], x[0]), + ) + return [p[0] for p in zipped], [p[1] for p in zipped] + + def forward_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + ) -> Dict[str, np.ndarray]: + """Return atomic prediction. + + Parameters + ---------- + extended_coord + coodinates in extended region, (nframes, nall * 3) + extended_atype + atomic type in extended region, (nframes, nall) + nlist + neighbor list, (nframes, nloc, nsel). + mapping + mapps the extended indices to local indices. + fparam + frame parameter. (nframes, ndf) + aparam + atomic parameter. (nframes, nloc, nda) + + Returns + ------- + result_dict + the result dict, defined by the fitting net output def. + """ + nframes, nloc, nnei = nlist.shape + extended_coord = extended_coord.reshape(nframes, -1, 3) + sorted_rcuts, sorted_sels = self._sort_rcuts_sels() + nlists = build_multiple_neighbor_list( + extended_coord, + nlist, + sorted_rcuts, + sorted_sels, + ) + raw_nlists = [ + nlists[get_multiple_nlist_key(rcut, sel)] + for rcut, sel in zip(self.get_model_rcuts(), self.get_model_nsels()) + ] + nlists_ = [ + nl if not dt else nlist_distinguish_types(nl, extended_atype, sel) + for dt, nl, sel in zip( + self.distinguish_type_list, raw_nlists, self.get_model_sels() + ) + ] + ener_list = [ + model.forward_atomic( + extended_coord, + extended_atype, + nl, + mapping, + fparam, + aparam, + )["energy"] + for model, nl in zip(self.models, nlists_) + ] + self.weights = self._compute_weight(extended_coord, extended_atype, nlists_) + self.atomic_bias = None + if self.atomic_bias is not None: + raise NotImplementedError("Need to add bias in a future PR.") + else: + fit_ret = { + "energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0), + } # (nframes, nloc, 1) + return fit_ret + + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ) + ] + ) + + @staticmethod + def serialize(models) -> dict: + return { + "models": [model.serialize() for model in models], + "model_name": [model.__class__.__name__ for model in models], + } + + @staticmethod + def deserialize(data) -> List[BaseAtomicModel]: + model_names = data["model_name"] + models = [ + getattr(sys.modules[__name__], name).deserialize(model) + for name, model in zip(model_names, data["models"]) + ] + return models + + @abstractmethod + def _compute_weight( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlists_: List[np.ndarray], + ) -> np.ndarray: + """This should be a list of user defined weights that matches the number of models to be combined.""" + raise NotImplementedError + + +class DPZBLLinearAtomicModel(LinearAtomicModel): + """Model linearly combine a list of AtomicModels. + + Parameters + ---------- + models + This linear model should take a DPAtomicModel and a PairTable model. + """ + + def __init__( + self, + dp_model: DPAtomicModel, + zbl_model: PairTabModel, + sw_rmin: float, + sw_rmax: float, + smin_alpha: Optional[float] = 0.1, + **kwargs, + ): + models = [dp_model, zbl_model] + super().__init__(models, **kwargs) + self.dp_model = dp_model + self.zbl_model = zbl_model + + self.sw_rmin = sw_rmin + self.sw_rmax = sw_rmax + self.smin_alpha = smin_alpha + + def serialize(self) -> dict: + return { + "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), + "sw_rmin": self.sw_rmin, + "sw_rmax": self.sw_rmax, + "smin_alpha": self.smin_alpha, + } + + @classmethod + def deserialize(cls, data) -> "DPZBLLinearAtomicModel": + sw_rmin = data["sw_rmin"] + sw_rmax = data["sw_rmax"] + smin_alpha = data["smin_alpha"] + + dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"]) + + return cls( + dp_model=dp_model, + zbl_model=zbl_model, + sw_rmin=sw_rmin, + sw_rmax=sw_rmax, + smin_alpha=smin_alpha, + ) + + def _compute_weight( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlists_: List[np.ndarray], + ) -> List[np.ndarray]: + """ZBL weight. + + Returns + ------- + List[np.ndarray] + the atomic ZBL weight for interpolation. (nframes, nloc, 1) + """ + assert ( + self.sw_rmax > self.sw_rmin + ), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`." + + dp_nlist = nlists_[0] + zbl_nlist = nlists_[1] + + zbl_nnei = zbl_nlist.shape[-1] + dp_nnei = dp_nlist.shape[-1] + + # use the larger rr based on nlist + nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist + masked_nlist = np.clip(nlist_larger, 0, None) + pairwise_rr = PairTabModel._get_pairwise_dist(extended_coord, masked_nlist) + + numerator = np.sum( + pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha), axis=-1 + ) # masked nnei will be zero, no need to handle + denominator = np.sum( + np.where( + nlist_larger != -1, + np.exp(-pairwise_rr / self.smin_alpha), + np.zeros_like(nlist_larger), + ), + axis=-1, + ) # handle masked nnei. + sigma = numerator / denominator + u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin) + coef = np.zeros_like(u) + left_mask = sigma < self.sw_rmin + mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax) + right_mask = sigma >= self.sw_rmax + coef[left_mask] = 1 + smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1 + coef[mid_mask] = smooth[mid_mask] + coef[right_mask] = 0 + self.zbl_weight = coef + return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)] diff --git a/deepmd/dpmodel/model/pair_tab_model.py b/deepmd/dpmodel/model/pairtab_atomic_model.py similarity index 82% rename from deepmd/dpmodel/model/pair_tab_model.py rename to deepmd/dpmodel/model/pairtab_atomic_model.py index dc658d8662..d4feb970fb 100644 --- a/deepmd/dpmodel/model/pair_tab_model.py +++ b/deepmd/dpmodel/model/pairtab_atomic_model.py @@ -82,7 +82,10 @@ def fitting_output_def(self) -> FittingOutputDef: def get_rcut(self) -> float: return self.rcut - def get_sel(self) -> int: + def get_sel(self) -> List[int]: + return [self.sel] + + def get_nsel(self) -> int: return self.sel def distinguish_types(self) -> bool: @@ -109,21 +112,20 @@ def forward_atomic( extended_atype, nlist, mapping: Optional[np.ndarray] = None, - do_atomic_virial: bool = False, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, ) -> Dict[str, np.ndarray]: - self.nframes, self.nloc, self.nnei = nlist.shape - extended_coord = extended_coord.reshape(self.nframes, -1, 3) + nframes, nloc, nnei = nlist.shape + extended_coord = extended_coord.reshape(nframes, -1, 3) # this will mask all -1 in the nlist - masked_nlist = np.clip(nlist, 0, None) - - atype = extended_atype[:, : self.nloc] # (nframes, nloc) - pairwise_dr = self._get_pairwise_dist( - extended_coord - ) # (nframes, nall, nall, 3) - pairwise_rr = np.sqrt( - np.sum(np.power(pairwise_dr, 2), axis=-1) - ) # (nframes, nall, nall) + mask = nlist >= 0 + masked_nlist = nlist * mask + + atype = extended_atype[:, :nloc] # (nframes, nloc) + pairwise_rr = self._get_pairwise_dist( + extended_coord, masked_nlist + ) # (nframes, nloc, nnei) self.tab_data = self.tab_data.reshape( self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 ) @@ -133,13 +135,13 @@ def forward_atomic( np.arange(extended_atype.shape[0])[:, None, None], masked_nlist ] - # slice rr to get (nframes, nloc, nnei) - rr = np.take_along_axis(pairwise_rr[:, : self.nloc, :], masked_nlist, 2) - raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr) + raw_atomic_energy = self._pair_tabulated_inter( + nlist, atype, j_type, pairwise_rr + ) atomic_energy = 0.5 * np.sum( np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)), axis=-1, - ).reshape(self.nframes, self.nloc, 1) + ).reshape(nframes, nloc, 1) return {"energy": atomic_energy} @@ -178,17 +180,18 @@ def _pair_tabulated_inter( This function is used to calculate the pairwise energy between two atoms. It uses a table containing cubic spline coefficients calculated in PairTab. """ + nframes, nloc, nnei = nlist.shape rmin = self.tab_info[0] hh = self.tab_info[1] hi = 1.0 / hh - self.nspline = int(self.tab_info[2] + 0.1) + nspline = int(self.tab_info[2] + 0.1) uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei) # if nnei of atom 0 has -1 in the nlist, uu would be 0. # this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms. - uu = np.where(nlist != -1, uu, self.nspline + 1) + uu = np.where(nlist != -1, uu, nspline + 1) if np.any(uu < 0): raise Exception("coord go beyond table lower boundary") @@ -197,34 +200,42 @@ def _pair_tabulated_inter( uu -= idx table_coef = self._extract_spline_coefficient( - i_type, j_type, idx, self.tab_data, self.nspline + i_type, j_type, idx, self.tab_data, nspline ) - table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4) - ener = self._calcualte_ener(table_coef, uu) + table_coef = table_coef.reshape(nframes, nloc, nnei, 4) + ener = self._calculate_ener(table_coef, uu) # here we need to overwrite energy to zero at rcut and beyond. mask_beyond_rcut = rr >= self.rcut # also overwrite values beyond extrapolation to zero - extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh + extrapolation_mask = rr >= self.tab.rmin + nspline * self.tab.hh ener[mask_beyond_rcut] = 0 ener[extrapolation_mask] = 0 return ener @staticmethod - def _get_pairwise_dist(coords: np.ndarray) -> np.ndarray: + def _get_pairwise_dist(coords: np.ndarray, nlist: np.ndarray) -> np.ndarray: """Get pairwise distance `dr`. Parameters ---------- coords : np.ndarray - The coordinate of the atoms shape of (nframes, nall, 3). + The coordinate of the atoms, shape of (nframes, nall, 3). + nlist + The masked nlist, shape of (nframes, nloc, nnei). Returns ------- np.ndarray - The pairwise distance between the atoms (nframes, nall, nall, 3). + The pairwise distance between the atoms (nframes, nloc, nnei). """ - return np.expand_dims(coords, 2) - np.expand_dims(coords, 1) + batch_indices = np.arange(nlist.shape[0])[:, None, None] + neighbor_atoms = coords[batch_indices, nlist] + loc_atoms = coords[:, : nlist.shape[1], :] + pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms + pairwise_rr = np.sqrt(np.sum(np.power(pairwise_dr, 2), axis=-1)) + + return pairwise_rr @staticmethod def _extract_spline_coefficient( @@ -279,7 +290,7 @@ def _extract_spline_coefficient( return final_coef @staticmethod - def _calcualte_ener(coef: np.ndarray, uu: np.ndarray) -> np.ndarray: + def _calculate_ener(coef: np.ndarray, uu: np.ndarray) -> np.ndarray: """Calculate energy using spline coeeficients. Parameters diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index bc6592d52b..657d6ecee2 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -169,7 +169,7 @@ def build_multiple_neighbor_list( nall = coord1.shape[1] coord0 = coord1[:, :nloc, :] nlist_mask = nlist == -1 - tnlist_0 = nlist + tnlist_0 = nlist.copy() tnlist_0[nlist_mask] = 0 index = np.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3]) coord2 = np.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index c4de02ed20..6cbab5af4d 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -4,18 +4,62 @@ from deepmd.pt.model.descriptor.descriptor import ( Descriptor, ) +from deepmd.pt.model.model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.pt.model.model.pairtab_atomic_model import ( + PairTabModel, +) from deepmd.pt.model.task import ( Fitting, ) from .ener import ( EnergyModel, + ZBLModel, ) from .model import ( BaseModel, ) +def get_zbl_model(model_params, sampled=None): + model_params = copy.deepcopy(model_params) + ntypes = len(model_params["type_map"]) + # descriptor + model_params["descriptor"]["ntypes"] = ntypes + descriptor = Descriptor(**model_params["descriptor"]) + # fitting + fitting_net = model_params.get("fitting_net", None) + fitting_net["type"] = fitting_net.get("type", "ener") + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["distinguish_types"] = descriptor.distinguish_types() + fitting_net["embedding_width"] = descriptor.get_dim_out() + grad_force = "direct" not in fitting_net["type"] + if not grad_force: + fitting_net["out_dim"] = descriptor.get_dim_emb() + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + fitting = Fitting(**fitting_net) + dp_model = DPAtomicModel( + descriptor, fitting, type_map=model_params["type_map"], resuming=True + ) + # pairtab + filepath = model_params["use_srtab"] + pt_model = PairTabModel( + filepath, model_params["descriptor"]["rcut"], model_params["descriptor"]["sel"] + ) + + rmin = model_params["sw_rmin"] + rmax = model_params["sw_rmax"] + return ZBLModel( + dp_model, + pt_model, + rmin, + rmax, + ) + + def get_model(model_params, sampled=None): model_params = copy.deepcopy(model_params) ntypes = len(model_params["type_map"]) diff --git a/deepmd/pt/model/model/ener.py b/deepmd/pt/model/model/ener.py index 9a6e60d963..ea35cf5a82 100644 --- a/deepmd/pt/model/model/ener.py +++ b/deepmd/pt/model/model/ener.py @@ -9,11 +9,84 @@ from .dp_atomic_model import ( DPAtomicModel, ) +from .linear_atomic_model import ( + DPZBLLinearAtomicModel, +) from .make_model import ( make_model, ) DPModel = make_model(DPAtomicModel) +ZBLModel_ = make_model(DPZBLLinearAtomicModel) + + +class ZBLModel(ZBLModel_): + model_type = "ener" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, + atype, + box, + ) + + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + else: + model_predict["force"] = model_ret["dforce"] + return model_predict + + def forward_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + ) + + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -2 + ) + else: + assert model_ret["dforce"] is not None + model_predict["dforce"] = model_ret["dforce"] + model_predict = model_ret + return model_predict class EnergyModel(DPModel): diff --git a/deepmd/pt/model/model/linear_atomic_model.py b/deepmd/pt/model/model/linear_atomic_model.py new file mode 100644 index 0000000000..8b50f5e4f5 --- /dev/null +++ b/deepmd/pt/model/model/linear_atomic_model.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import sys +from abc import ( + abstractmethod, +) +from typing import ( + Dict, + List, + Optional, + Tuple, +) + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.utils.nlist import ( + build_multiple_neighbor_list, + get_multiple_nlist_key, + nlist_distinguish_types, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) +from .dp_atomic_model import ( + DPAtomicModel, +) +from .model import ( + BaseModel, +) +from .pairtab_atomic_model import ( + PairTabModel, +) + + +class LinearAtomicModel(BaseModel, BaseAtomicModel): + """Linear model make linear combinations of several existing models. + + Parameters + ---------- + models : list[DPAtomicModel or PairTabModel] + A list of models to be combined. PairTabModel must be used together with a DPAtomicModel. + """ + + def __init__( + self, + models: List[BaseAtomicModel], + **kwargs, + ): + super().__init__() + self.models = torch.nn.ModuleList(models) + self.atomic_bias = None + self.distinguish_type_list = [ + model.distinguish_types() for model in self.models + ] + + def distinguish_types(self) -> bool: + """If distinguish different types by sorting.""" + return False + + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return max(self.get_model_rcuts()) + + def get_model_rcuts(self) -> List[float]: + """Get the cut-off radius for each individual models.""" + return [model.get_rcut() for model in self.models] + + def get_sel(self) -> List[int]: + return [max([model.get_nsel() for model in self.models])] + + def get_model_nsels(self) -> List[int]: + """Get the processed sels for each individual models. Not distinguishing types.""" + return [model.get_nsel() for model in self.models] + + def get_model_sels(self) -> List[List[int]]: + """Get the sels for each individual models.""" + return [model.get_sel() for model in self.models] + + def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]: + # sort the pair of rcut and sels in ascending order, first based on sel, then on rcut. + rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64) + nsels = torch.tensor(self.get_model_nsels()) + zipped = torch.stack([torch.tensor(rcuts), torch.tensor(nsels)], dim=0).T + inner_sorting = torch.argsort(zipped[:, 1], dim=0) + inner_sorted = zipped[inner_sorting] + outer_sorting = torch.argsort(inner_sorted[:, 0], stable=True) + outer_sorted = inner_sorted[outer_sorting] + sorted_rcuts: List[float] = outer_sorted[:, 0].tolist() + sorted_sels: List[int] = outer_sorted[:, 1].to(torch.int64).tolist() + return sorted_rcuts, sorted_sels + + def forward_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + """Return atomic prediction. + + Parameters + ---------- + extended_coord + coodinates in extended region, (nframes, nall * 3) + extended_atype + atomic type in extended region, (nframes, nall) + nlist + neighbor list, (nframes, nloc, nsel). + mapping + mapps the extended indices to local indices. + fparam + frame parameter. (nframes, ndf) + aparam + atomic parameter. (nframes, nloc, nda) + + Returns + ------- + result_dict + the result dict, defined by the fitting net output def. + """ + nframes, nloc, nnei = nlist.shape + if self.do_grad(): + extended_coord.requires_grad_(True) + extended_coord = extended_coord.view(nframes, -1, 3) + sorted_rcuts, sorted_sels = self._sort_rcuts_sels() + nlists = build_multiple_neighbor_list( + extended_coord, + nlist, + sorted_rcuts, + sorted_sels, + ) + raw_nlists = [ + nlists[get_multiple_nlist_key(rcut, sel)] + for rcut, sel in zip(self.get_model_rcuts(), self.get_model_nsels()) + ] + nlists_ = [ + nl if not dt else nlist_distinguish_types(nl, extended_atype, sel) + for dt, nl, sel in zip( + self.distinguish_type_list, raw_nlists, self.get_model_sels() + ) + ] + ener_list = [] + + for i, model in enumerate(self.models): + ener_list.append( + model.forward_atomic( + extended_coord, + extended_atype, + nlists_[i], + mapping, + fparam, + aparam, + )["energy"] + ) + + weights = self._compute_weight(extended_coord, extended_atype, nlists_) + + if self.atomic_bias is not None: + raise NotImplementedError("Need to add bias in a future PR.") + else: + fit_ret = { + "energy": torch.sum( + torch.stack(ener_list) * torch.stack(weights), dim=0 + ), + } # (nframes, nloc, 1) + return fit_ret + + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ) + ] + ) + + @staticmethod + def serialize(models) -> dict: + return { + "models": [model.serialize() for model in models], + "model_name": [model.__class__.__name__ for model in models], + } + + @staticmethod + def deserialize(data) -> List[BaseAtomicModel]: + model_names = data["model_name"] + models = [ + getattr(sys.modules[__name__], name).deserialize(model) + for name, model in zip(model_names, data["models"]) + ] + return models + + @abstractmethod + def _compute_weight( + self, extended_coord, extended_atype, nlists_ + ) -> List[torch.Tensor]: + """This should be a list of user defined weights that matches the number of models to be combined.""" + raise NotImplementedError + + +class DPZBLLinearAtomicModel(LinearAtomicModel): + """Model linearly combine a list of AtomicModels. + + Parameters + ---------- + models + This linear model should take a DPAtomicModel and a PairTable model. + """ + + def __init__( + self, + dp_model: DPAtomicModel, + zbl_model: PairTabModel, + sw_rmin: float, + sw_rmax: float, + smin_alpha: Optional[float] = 0.1, + **kwargs, + ): + models = [dp_model, zbl_model] + super().__init__(models, **kwargs) + self.dp_model = dp_model + self.zbl_model = zbl_model + + self.sw_rmin = sw_rmin + self.sw_rmax = sw_rmax + self.smin_alpha = smin_alpha + + # this is a placeholder being updated in _compute_weight, to handle Jit attribute init error. + self.zbl_weight = torch.empty(0, dtype=torch.float64) + + def serialize(self) -> dict: + return { + "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), + "sw_rmin": self.sw_rmin, + "sw_rmax": self.sw_rmax, + "smin_alpha": self.smin_alpha, + } + + @classmethod + def deserialize(cls, data) -> "DPZBLLinearAtomicModel": + sw_rmin = data["sw_rmin"] + sw_rmax = data["sw_rmax"] + smin_alpha = data["smin_alpha"] + + dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"]) + + return cls( + dp_model=dp_model, + zbl_model=zbl_model, + sw_rmin=sw_rmin, + sw_rmax=sw_rmax, + smin_alpha=smin_alpha, + ) + + def _compute_weight( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlists_: List[torch.Tensor], + ) -> List[torch.Tensor]: + """ZBL weight. + + Returns + ------- + List[torch.Tensor] + the atomic ZBL weight for interpolation. (nframes, nloc, 1) + """ + assert ( + self.sw_rmax > self.sw_rmin + ), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`." + + dp_nlist = nlists_[0] + zbl_nlist = nlists_[1] + + zbl_nnei = zbl_nlist.shape[-1] + dp_nnei = dp_nlist.shape[-1] + + # use the larger rr based on nlist + nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist + masked_nlist = torch.clamp(nlist_larger, 0) + pairwise_rr = PairTabModel._get_pairwise_dist(extended_coord, masked_nlist) + numerator = torch.sum( + pairwise_rr * torch.exp(-pairwise_rr / self.smin_alpha), dim=-1 + ) # masked nnei will be zero, no need to handle + denominator = torch.sum( + torch.where( + nlist_larger != -1, + torch.exp(-pairwise_rr / self.smin_alpha), + torch.zeros_like(nlist_larger), + ), + dim=-1, + ) # handle masked nnei. + + sigma = numerator / denominator # nfrmes, nloc + u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin) + coef = torch.zeros_like(u) + left_mask = sigma < self.sw_rmin + mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax) + right_mask = sigma >= self.sw_rmax + coef[left_mask] = 1 + smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1 + coef[mid_mask] = smooth[mid_mask] + coef[right_mask] = 0 + self.zbl_weight = coef # nframes, nloc + return [1 - coef.unsqueeze(-1), coef.unsqueeze(-1)] # to match the model order. diff --git a/deepmd/pt/model/model/pair_tab_model.py b/deepmd/pt/model/model/pairtab_atomic_model.py similarity index 79% rename from deepmd/pt/model/model/pair_tab_model.py rename to deepmd/pt/model/model/pairtab_atomic_model.py index 83089f86b4..98215191c1 100644 --- a/deepmd/pt/model/model/pair_tab_model.py +++ b/deepmd/pt/model/model/pairtab_atomic_model.py @@ -54,7 +54,7 @@ def __init__( super().__init__() self.tab_file = tab_file self.rcut = rcut - self.tab = PairTab(self.tab_file, rcut=rcut) + self.tab = self._set_pairtab(tab_file, rcut) # handle deserialization with no input file if self.tab_file is not None: @@ -78,6 +78,10 @@ def __init__( else: raise TypeError("sel must be int or list[int]") + @torch.jit.ignore + def _set_pairtab(self, tab_file: str, rcut: float) -> PairTab: + return PairTab(tab_file, rcut) + def fitting_output_def(self) -> FittingOutputDef: return FittingOutputDef( [ @@ -94,7 +98,10 @@ def fitting_output_def(self) -> FittingOutputDef: def get_rcut(self) -> float: return self.rcut - def get_sel(self) -> int: + def get_sel(self) -> List[int]: + return [self.sel] + + def get_nsel(self) -> int: return self.sel def distinguish_types(self) -> bool: @@ -117,39 +124,41 @@ def deserialize(cls, data) -> "PairTabModel": def forward_atomic( self, - extended_coord, - extended_atype, - nlist, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ) -> Dict[str, torch.Tensor]: - self.nframes, self.nloc, self.nnei = nlist.shape - extended_coord = extended_coord.view(self.nframes, -1, 3) + nframes, nloc, nnei = nlist.shape + extended_coord = extended_coord.view(nframes, -1, 3) + if self.do_grad(): + extended_coord.requires_grad_(True) # this will mask all -1 in the nlist - masked_nlist = torch.clamp(nlist, 0) - - atype = extended_atype[:, : self.nloc] # (nframes, nloc) - pairwise_dr = self._get_pairwise_dist( - extended_coord - ) # (nframes, nall, nall, 3) - pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall) + mask = nlist >= 0 + masked_nlist = nlist * mask + atype = extended_atype[:, :nloc] # (nframes, nloc) + pairwise_rr = self._get_pairwise_dist( + extended_coord, masked_nlist + ) # (nframes, nloc, nnei) self.tab_data = self.tab_data.view( - self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 + int(self.tab_info[-1]), int(self.tab_info[-1]), int(self.tab_info[2]), 4 ) - # to calculate the atomic_energy, we need 3 tensors, i_type, j_type, rr + # to calculate the atomic_energy, we need 3 tensors, i_type, j_type, pairwise_rr # i_type : (nframes, nloc), this is atype. # j_type : (nframes, nloc, nnei) j_type = extended_atype[ torch.arange(extended_atype.size(0))[:, None, None], masked_nlist ] - # slice rr to get (nframes, nloc, nnei) - rr = torch.gather(pairwise_rr[:, : self.nloc, :], 2, masked_nlist) - - raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr) + raw_atomic_energy = self._pair_tabulated_inter( + nlist, atype, j_type, pairwise_rr + ) atomic_energy = 0.5 * torch.sum( torch.where( @@ -195,17 +204,18 @@ def _pair_tabulated_inter( This function is used to calculate the pairwise energy between two atoms. It uses a table containing cubic spline coefficients calculated in PairTab. """ + nframes, nloc, nnei = nlist.shape rmin = self.tab_info[0] hh = self.tab_info[1] hi = 1.0 / hh - self.nspline = int(self.tab_info[2] + 0.1) + nspline = int(self.tab_info[2] + 0.1) uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei) # if nnei of atom 0 has -1 in the nlist, uu would be 0. # this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms. - uu = torch.where(nlist != -1, uu, self.nspline + 1) + uu = torch.where(nlist != -1, uu, nspline + 1) if torch.any(uu < 0): raise Exception("coord go beyond table lower boundary") @@ -215,57 +225,44 @@ def _pair_tabulated_inter( uu -= idx table_coef = self._extract_spline_coefficient( - i_type, j_type, idx, self.tab_data, self.nspline + i_type, j_type, idx, self.tab_data, nspline ) - table_coef = table_coef.view(self.nframes, self.nloc, self.nnei, 4) - ener = self._calcualte_ener(table_coef, uu) + table_coef = table_coef.view(nframes, nloc, nnei, 4) + ener = self._calculate_ener(table_coef, uu) # here we need to overwrite energy to zero at rcut and beyond. mask_beyond_rcut = rr >= self.rcut # also overwrite values beyond extrapolation to zero - extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh + extrapolation_mask = rr >= rmin + nspline * hh ener[mask_beyond_rcut] = 0 ener[extrapolation_mask] = 0 return ener @staticmethod - def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor: + def _get_pairwise_dist(coords: torch.Tensor, nlist: torch.Tensor) -> torch.Tensor: """Get pairwise distance `dr`. Parameters ---------- coords : torch.Tensor - The coordinate of the atoms shape of (nframes, nall, 3). + The coordinate of the atoms, shape of (nframes, nall, 3). + nlist + The masked nlist, shape of (nframes, nloc, nnei) Returns ------- torch.Tensor - The pairwise distance between the atoms (nframes, nall, nall, 3). - - Examples - -------- - coords = torch.tensor([[ - [0,0,0], - [1,3,5], - [2,4,6] - ]]) - - dist = tensor([[ - [[ 0, 0, 0], - [-1, -3, -5], - [-2, -4, -6]], - - [[ 1, 3, 5], - [ 0, 0, 0], - [-1, -1, -1]], - - [[ 2, 4, 6], - [ 1, 1, 1], - [ 0, 0, 0]] - ]]) + The pairwise distance between the atoms (nframes, nloc, nnei). """ - return coords.unsqueeze(2) - coords.unsqueeze(1) + nframes, nloc, nnei = nlist.shape + coord_l = coords[:, :nloc].view(nframes, -1, 1, 3) + index = nlist.view(nframes, -1).unsqueeze(-1).expand(-1, -1, 3) + coord_r = torch.gather(coords, 1, index) + coord_r = coord_r.view(nframes, nloc, nnei, 3) + diff = coord_r - coord_l + pairwise_rr = torch.linalg.norm(diff, dim=-1, keepdim=True).squeeze(-1) + return pairwise_rr @staticmethod def _extract_spline_coefficient( @@ -316,7 +313,7 @@ def _extract_spline_coefficient( return final_coef @staticmethod - def _calcualte_ener(coef: torch.Tensor, uu: torch.Tensor) -> torch.Tensor: + def _calculate_ener(coef: torch.Tensor, uu: torch.Tensor) -> torch.Tensor: """Calculate energy using spline coeeficients. Parameters diff --git a/source/tests/common/dpmodel/test_linear_atomic_model.py b/source/tests/common/dpmodel/test_linear_atomic_model.py new file mode 100644 index 0000000000..6dcff97d74 --- /dev/null +++ b/source/tests/common/dpmodel/test_linear_atomic_model.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np + +from deepmd.dpmodel.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.dpmodel.fitting.invar_fitting import ( + InvarFitting, +) +from deepmd.dpmodel.model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.model.linear_atomic_model import ( + DPZBLLinearAtomicModel, +) +from deepmd.dpmodel.model.pairtab_atomic_model import ( + PairTabModel, +) + + +class TestWeightCalculation(unittest.TestCase): + @patch("numpy.loadtxt") + def test_pairwise(self, mock_loadtxt): + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.05, 1.0, 2.0, 3.0], + [0.1, 0.8, 1.6, 2.4], + [0.15, 0.5, 1.0, 1.5], + [0.2, 0.25, 0.4, 0.75], + [0.25, 0.0, 0.0, 0.0], + ] + ) + extended_atype = np.array([[0, 0]]) + nlist = np.array([[[1], [-1]]]) + + ds = DescrptSeA( + rcut=0.3, + rcut_smth=0.4, + sel=[3], + ) + ft = InvarFitting( + "energy", + 2, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + + type_map = ["foo", "bar"] + zbl_model = PairTabModel(tab_file=file_path, rcut=0.3, sel=2) + dp_model = DPAtomicModel(ds, ft, type_map=type_map) + + wgt_model = DPZBLLinearAtomicModel( + dp_model, + zbl_model, + sw_rmin=0.1, + sw_rmax=0.25, + ) + wgt_res = [] + for dist in np.linspace(0.05, 0.3, 10): + extended_coord = np.array( + [ + [ + [0.0, 0.0, 0.0], + [0.0, dist, 0.0], + ], + ] + ) + + wgt_model.forward_atomic(extended_coord, extended_atype, nlist) + + wgt_res.append(wgt_model.zbl_weight) + results = np.stack(wgt_res).reshape(10, 2) + excepted_res = np.array( + [ + [1.0, 0.0], + [1.0, 0.0], + [0.9995, 0.0], + [0.9236, 0.0], + [0.6697, 0.0], + [0.3303, 0.0], + [0.0764, 0.0], + [0.0005, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ], + ) + np.testing.assert_allclose(results, excepted_res, rtol=0.0001, atol=0.0001) + + +class TestIntegration(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt): + self.nloc = 3 + self.nall = 4 + self.nf, self.nt = 1, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall * 3]) + self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) + self.sel = [5, 2] + self.nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.rcut = 0.4 + self.rcut_smth = 2.2 + + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ) + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + dp_model = DPAtomicModel(ds, ft, type_map=type_map) + zbl_model = PairTabModel(file_path, self.rcut, sum(self.sel)) + self.md0 = DPZBLLinearAtomicModel( + dp_model, + zbl_model, + sw_rmin=0.1, + sw_rmax=0.25, + ) + self.md1 = DPZBLLinearAtomicModel.deserialize(self.md0.serialize()) + + def test_self_consistency(self): + ret0 = self.md0.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) + ret1 = self.md1.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) + np.testing.assert_allclose( + ret0["energy"], + ret1["energy"], + ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/source/tests/common/dpmodel/test_pairtab.py b/source/tests/common/dpmodel/test_pairtab_atomic_model.py similarity index 97% rename from source/tests/common/dpmodel/test_pairtab.py rename to source/tests/common/dpmodel/test_pairtab_atomic_model.py index 3713d33510..f1e7bd257c 100644 --- a/source/tests/common/dpmodel/test_pairtab.py +++ b/source/tests/common/dpmodel/test_pairtab_atomic_model.py @@ -6,7 +6,7 @@ import numpy as np -from deepmd.dpmodel.model.pair_tab_model import ( +from deepmd.dpmodel.model.pairtab_atomic_model import ( PairTabModel, ) @@ -199,5 +199,6 @@ def test_extrapolation_nonzero_rmax(self, mock_loadtxt) -> None: np.testing.assert_allclose(results, expected_result, 0.0001, 0.0001) - if __name__ == "__main__": - unittest.main() + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/source/tests/common/dpmodel/test_pairtab_preprocess.py b/source/tests/common/dpmodel/test_pairtab_preprocess.py index 26f96a3ca4..da3b9251f7 100644 --- a/source/tests/common/dpmodel/test_pairtab_preprocess.py +++ b/source/tests/common/dpmodel/test_pairtab_preprocess.py @@ -273,3 +273,7 @@ def test_preprocess(self): rtol=1e-03, atol=1e-03, ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index 8840fbdd4c..24dc69458d 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -7,6 +7,7 @@ from deepmd.pt.model.model import ( get_model, + get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -20,6 +21,7 @@ model_dpa1, model_dpa2, model_se_e2_a, + model_zbl, ) @@ -190,3 +192,19 @@ def setUp(self): model_params = copy.deepcopy(model_dpa2) self.type_split = True self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelZBLForce(unittest.TestCase, ForceTest): + def setUp(self): + model_params = copy.deepcopy(model_zbl) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_zbl_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelZBLVirial(unittest.TestCase, VirialTest): + def setUp(self): + model_params = copy.deepcopy(model_zbl) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_zbl_model(model_params, sampled).to(env.DEVICE) diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py new file mode 100644 index 0000000000..211b1f8215 --- /dev/null +++ b/source/tests/pt/model/test_linear_atomic_model.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np +import torch + +from deepmd.dpmodel.model.linear_atomic_model import ( + DPZBLLinearAtomicModel as DPDPZBLLinearAtomicModel, +) +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.pt.model.model.ener import ( + ZBLModel, +) +from deepmd.pt.model.model.linear_atomic_model import ( + DPZBLLinearAtomicModel, +) +from deepmd.pt.model.model.pairtab_atomic_model import ( + PairTabModel, +) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestWeightCalculation(unittest.TestCase): + @patch("numpy.loadtxt") + def test_pairwise(self, mock_loadtxt): + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.05, 1.0, 2.0, 3.0], + [0.1, 0.8, 1.6, 2.4], + [0.15, 0.5, 1.0, 1.5], + [0.2, 0.25, 0.4, 0.75], + [0.25, 0.0, 0.0, 0.0], + ] + ) + extended_atype = torch.tensor([[0, 0]]) + nlist = torch.tensor([[[1], [-1]]]) + + ds = DescrptSeA( + rcut=0.3, + rcut_smth=0.4, + sel=[3], + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + 2, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + + type_map = ["foo", "bar"] + zbl_model = PairTabModel(tab_file=file_path, rcut=0.3, sel=2) + dp_model = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to( + env.DEVICE + ) + wgt_model = DPZBLLinearAtomicModel( + dp_model, + zbl_model, + sw_rmin=0.1, + sw_rmax=0.25, + ) + wgt_res = [] + for dist in np.linspace(0.05, 0.3, 10): + extended_coord = torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [0.0, dist, 0.0], + ], + ] + ) + + wgt_model.forward_atomic(extended_coord, extended_atype, nlist) + + wgt_res.append(wgt_model.zbl_weight) + results = torch.stack(wgt_res).reshape(10, 2) + excepted_res = torch.tensor( + [ + [1.0, 0.0], + [1.0, 0.0], + [0.9995, 0.0], + [0.9236, 0.0], + [0.6697, 0.0], + [0.3303, 0.0], + [0.0764, 0.0], + [0.0005, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ], + dtype=torch.float64, + ) + torch.testing.assert_close(results, excepted_res, rtol=0.0001, atol=0.0001) + + +class TestIntegration(unittest.TestCase, TestCaseSingleFrameWithNlist): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt): + TestCaseSingleFrameWithNlist.setUp(self) + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ) + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + dp_model = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to( + env.DEVICE + ) + zbl_model = PairTabModel(file_path, self.rcut, sum(self.sel)) + self.md0 = DPZBLLinearAtomicModel( + dp_model, + zbl_model, + sw_rmin=0.1, + sw_rmax=0.25, + ).to(env.DEVICE) + self.md1 = DPZBLLinearAtomicModel.deserialize(self.md0.serialize()).to( + env.DEVICE + ) + self.md2 = DPDPZBLLinearAtomicModel.deserialize(self.md0.serialize()) + self.md3 = ZBLModel(dp_model, zbl_model, sw_rmin=0.1, sw_rmax=0.25) + + def test_self_consistency(self): + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = self.md0.forward_atomic(*args) + ret1 = self.md1.forward_atomic(*args) + ret2 = self.md2.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) + + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), ret2["energy"], atol=0.001, rtol=0.001 + ) + + def test_jit(self): + torch.jit.script(self.md1) + torch.jit.script(self.md3) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/source/tests/pt/model/test_pairtab.py b/source/tests/pt/model/test_pairtab_atomic_model.py similarity index 95% rename from source/tests/pt/model/test_pairtab.py rename to source/tests/pt/model/test_pairtab_atomic_model.py index e27e2cf2a1..23718c134a 100644 --- a/source/tests/pt/model/test_pairtab.py +++ b/source/tests/pt/model/test_pairtab_atomic_model.py @@ -7,10 +7,13 @@ import numpy as np import torch -from deepmd.dpmodel.model.pair_tab_model import PairTabModel as DPPairTabModel -from deepmd.pt.model.model.pair_tab_model import ( +from deepmd.dpmodel.model.pairtab_atomic_model import PairTabModel as DPPairTabModel +from deepmd.pt.model.model.pairtab_atomic_model import ( PairTabModel, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) class TestPairTab(unittest.TestCase): @@ -114,9 +117,8 @@ def test_cross_deserialize(self): expected_result = self.model.forward_atomic( self.extended_coord, self.extended_atype, torch.from_numpy(self.nlist) ) - np.testing.assert_allclose( - result["energy"], expected_result["energy"], 0.0001, 0.0001 + result["energy"], to_numpy_array(expected_result["energy"]), 0.0001, 0.0001 ) @@ -235,5 +237,6 @@ def test_extrapolation_nonzero_rmax(self, mock_loadtxt) -> None: torch.testing.assert_close(results, expected_result, rtol=0.0001, atol=0.0001) - if __name__ == "__main__": - unittest.main() + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index b9724bb2af..2301b6ea10 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -12,6 +12,7 @@ ) from deepmd.pt.model.model import ( get_model, + get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -45,6 +46,30 @@ "data_stat_nbatch": 20, } +model_zbl = { + "type_map": ["O", "H", "B"], + "use_srtab": "source/tests/pt/model/water/data/zbl_tab_potential/H2O_tab_potential.txt", + "smin_alpha": 0.1, + "sw_rmin": 0.2, + "sw_rmax": 1.0, + "descriptor": { + "type": "se_e2_a", + "sel": [46, 92, 4], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 20, +} + model_dpa2 = { "type_map": ["O", "H", "B"], "descriptor": { @@ -302,6 +327,14 @@ def setUp(self): self.model = get_model(model_params, sampled).to(env.DEVICE) +class TestEnergyModelZBL(unittest.TestCase, PermutationTest): + def setUp(self): + model_params = copy.deepcopy(model_zbl) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_zbl_model(model_params, sampled).to(env.DEVICE) + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py index 7222fd6f69..982753e94f 100644 --- a/source/tests/pt/model/test_rot.py +++ b/source/tests/pt/model/test_rot.py @@ -9,6 +9,7 @@ ) from deepmd.pt.model.model import ( get_model, + get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -20,6 +21,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_zbl, ) dtype = torch.float64 @@ -177,5 +179,13 @@ def setUp(self): self.model = get_model(model_params, sampled).to(env.DEVICE) +class TestEnergyModelZBL(unittest.TestCase, RotTest): + def setUp(self): + model_params = copy.deepcopy(model_zbl) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_zbl_model(model_params, sampled).to(env.DEVICE) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index 2e3bf61d10..f2f45c74aa 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -9,6 +9,7 @@ ) from deepmd.pt.model.model import ( get_model, + get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -20,6 +21,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_zbl, ) dtype = torch.float64 @@ -210,6 +212,15 @@ def setUp(self): self.epsilon, self.aprec = None, None +class TestEnergyModelZBL(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_zbl) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_zbl_model(model_params, sampled).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py index e5d379b9ff..967d505c6d 100644 --- a/source/tests/pt/model/test_trans.py +++ b/source/tests/pt/model/test_trans.py @@ -9,6 +9,7 @@ ) from deepmd.pt.model.model import ( get_model, + get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -20,6 +21,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_zbl, ) dtype = torch.float64 @@ -133,5 +135,13 @@ def setUp(self): self.model = get_model(model_params, sampled).to(env.DEVICE) +class TestEnergyModelZBL(unittest.TestCase, TransTest): + def setUp(self): + model_params = copy.deepcopy(model_zbl) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_zbl_model(model_params, sampled).to(env.DEVICE) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/model/water/data/zbl_tab_potential/H2O_tab_potential.txt b/source/tests/pt/model/water/data/zbl_tab_potential/H2O_tab_potential.txt new file mode 100644 index 0000000000..66fcb8e946 --- /dev/null +++ b/source/tests/pt/model/water/data/zbl_tab_potential/H2O_tab_potential.txt @@ -0,0 +1,1000 @@ +0.0010 913709.625838 114389.26607 14320.660836 25838 114389.26607 14320.660836 +0.0020 453190.075792 56822.165078 7124.559066 75792 56822.165078 7124.559066 +0.0030 299716.609389 37635.860646 4726.059712 09389 37635.860646 4726.059712 +0.0040 223004.208152 28044.724786 3526.959232 08152 28044.724786 3526.959232 +0.0050 176995.875921 22291.632310 2807.616935 75921 22291.632310 2807.616935 +0.0060 146339.286793 18457.541826 2328.152606 86793 18457.541826 2328.152606 +0.0070 124454.877677 15720.007305 1985.760451 77677 15720.007305 1985.760451 +0.0080 108052.871443 13667.805976 1729.037583 71443 13667.805976 1729.037583 +0.0090 95305.6179694 12072.480958 1529.426853 79694 12072.480958 1529.426853 +0.0100 85116.5305655 10796.958308 1369.793979 05655 10796.958308 1369.793979 +0.0110 76787.7843454 9754.0093240 1239.235334 43454 9754.0093240 1239.235334 +0.0120 69854.1654175 8885.4816862 1130.481842 54175 8885.4816862 1130.481842 +0.0130 63993.6050636 8151.1162355 1038.501071 50636 8151.1162355 1038.501071 +0.0140 58976.0564146 7522.1565542 959.6984312 64146 7522.1565542 959.6984312 +0.0150 54632.8204564 6977.5147177 891.4378965 04564 6977.5147177 891.4378965 +0.0160 50837.3747846 6501.3748881 831.7424519 47846 6501.3748881 831.7424519 +0.0170 47492.9686820 6081.6426991 779.1002669 86820 6081.6426991 779.1002669 +0.0180 44524.3531708 5708.9115119 732.3354774 31708 5708.9115119 732.3354774 +0.0190 41872.1226283 5375.7551174 690.5197734 26283 5375.7551174 690.5197734 +0.0200 39488.7539185 5076.2326272 652.9105109 39185 5076.2326272 652.9105109 +0.0210 37335.7772003 4805.5348243 618.9065049 72003 4805.5348243 618.9065049 +0.0220 35381.7183353 4559.7269643 588.0158802 83353 4559.7269643 588.0158802 +0.0230 33600.5780582 4335.5586712 559.8323083 80582 4335.5586712 559.8323083 +0.0240 31970.6913556 4130.3213594 534.0171847 13556 4130.3213594 534.0171847 +0.0250 30473.8605947 3941.7398750 510.2860845 05947 3941.7398750 510.2860845 +0.0260 29094.6886984 3767.8891424 488.3983430 86984 3767.8891424 488.3983430 +0.0270 27820.0605059 3607.1293341 468.1489521 05059 3607.1293341 468.1489521 +0.0280 26638.7352692 3458.0549325 449.3621928 52692 3458.0549325 449.3621928 +0.0290 25541.0234620 3319.4543312 431.8865855 34620 3319.4543312 431.8865855 +0.0300 24518.5282265 3190.2775152 415.5908499 82265 3190.2775152 415.5908499 +0.0310 23563.9368637 3069.6099976 400.3606472 68637 3069.6099976 400.3606472 +0.0320 22670.8514191 2956.6516420 386.0959329 14191 2956.6516420 386.0959329 +0.0330 21833.6500715 2850.6993366 372.7087910 00715 2850.6993366 372.7087910 +0.0340 21047.3729830 2751.1327248 360.1216502 29830 2751.1327248 360.1216502 +0.0350 20307.6277175 2657.4023825 348.2658062 77175 2657.4023825 348.2658062 +0.0360 19610.5104235 2569.0199661 337.0801904 04235 2569.0199661 337.0801904 +0.0370 18952.5397978 2485.5499575 326.5103374 97978 2485.5499575 326.5103374 +0.0380 18330.6014769 2406.6027127 316.5075168 14769 2406.6027127 316.5075168 +0.0390 17741.9009829 2331.8285805 307.0279975 09829 2331.8285805 307.0279975 +0.0400 17183.9237284 2260.9129024 298.0324229 37284 2260.9129024 298.0324229 +0.0410 16654.4008745 2193.5717452 289.4852776 08745 2193.5717452 289.4852776 +0.0420 16151.2800661 2129.5482423 281.3544296 00661 2129.5482423 281.3544296 +0.0430 15672.7002509 2068.6094464 273.6107373 02509 2068.6094464 273.6107373 +0.0440 15216.9699328 2010.5436107 266.2277097 99328 2010.5436107 266.2277097 +0.0450 14782.5483242 1955.1578329 259.1812115 83242 1955.1578329 259.1812115 +0.0460 14368.0289580 1902.2760069 252.4492073 89580 1902.2760069 252.4492073 +0.0470 13972.1253902 1851.7370350 246.0115383 53902 1851.7370350 246.0115383 +0.0480 13593.6586918 1803.3932646 239.8497262 86918 1803.3932646 239.8497262 +0.0490 13231.5464708 1757.1091159 233.9468027 64708 1757.1091159 233.9468027 +0.0500 12884.7932124 1712.7598740 228.2871576 32124 1712.7598740 228.2871576 +0.0510 12552.4817558 1670.2306236 222.8564060 17558 1670.2306236 222.8564060 +0.0520 12233.7657548 1629.4153064 217.6412705 57548 1629.4153064 217.6412705 +0.0530 11927.8629910 1590.2158855 212.6294767 29910 1590.2158855 212.6294767 +0.0540 11634.0494314 1552.5416017 207.8096602 94314 1552.5416017 207.8096602 +0.0550 11351.6539336 1516.3083123 203.1712838 39336 1516.3083123 203.1712838 +0.0560 11080.0535186 1481.4378999 198.7045641 35186 1481.4378999 198.7045641 +0.0570 10818.6691413 1447.8577434 194.4004048 91413 1447.8577434 194.4004048 +0.0580 10566.9618984 1415.5002442 190.2503376 18984 1415.5002442 190.2503376 +0.0590 10324.4296227 1384.3024001 186.2464693 96227 1384.3024001 186.2464693 +0.0600 10090.6038173 1354.2054222 182.3814335 38173 1354.2054222 182.3814335 +0.0610 9865.0468917 1325.1543893 178.6483475 68917 1325.1543893 178.6483475 +0.0620 9647.3496659 1297.0979357 175.0407734 96659 1297.0979357 175.0407734 +0.0630 9437.1291115 1269.9879689 171.5526826 91115 1269.9879689 171.5526826 +0.0640 9234.0263053 1243.7794136 168.1784240 63053 1243.7794136 168.1784240 +0.0650 9037.7045714 1218.4299795 164.9126949 45714 1218.4299795 164.9126949 +0.0660 8847.8477928 1193.8999501 161.7505145 77928 1193.8999501 161.7505145 +0.0670 8664.1588738 1170.1519903 158.6871999 88738 1170.1519903 158.6871999 +0.0680 8486.3583383 1147.1509714 155.7183444 83383 1147.1509714 155.7183444 +0.0690 8314.1830501 1124.8638108 152.8397972 30501 1124.8638108 152.8397972 +0.0700 8147.3850427 1103.2593259 150.0476452 50427 1103.2593259 150.0476452 +0.0710 7985.7304489 1082.3080999 147.3381963 04489 1082.3080999 147.3381963 +0.0720 7828.9985183 1061.9823592 144.7079640 85183 1061.9823592 144.7079640 +0.0730 7676.9807165 1042.2558606 142.1536534 07165 1042.2558606 142.1536534 +0.0740 7529.4798977 1023.1037878 139.6721482 98977 1023.1037878 139.6721482 +0.0750 7386.3095424 1004.5026562 137.2604986 95424 1004.5026562 137.2604986 +0.0760 7247.2930565 986.4302250 134.9159107 930565 986.4302250 134.9159107 +0.0770 7112.2631243 968.8654164 132.6357361 631243 968.8654164 132.6357361 +0.0780 6981.0611116 951.7882409 130.4174626 611116 951.7882409 130.4174626 +0.0790 6853.5365143 935.1797284 128.2587056 365143 935.1797284 128.2587056 +0.0800 6729.5464483 919.0218641 126.1572004 464483 919.0218641 126.1572004 +0.0810 6608.9551768 903.2975297 124.1107942 551768 903.2975297 124.1107942 +0.0820 6491.6336731 887.9904484 122.1174397 336731 887.9904484 122.1174397 +0.0830 6377.4592142 873.0851342 120.1751889 592142 873.0851342 120.1751889 +0.0840 6266.3150042 858.5668449 118.2821865 150042 858.5668449 118.2821865 +0.0850 6158.0898234 844.4215378 116.4366651 898234 844.4215378 116.4366651 +0.0860 6052.6777030 830.6358295 114.6369400 777030 830.6358295 114.6369400 +0.0870 5949.9776216 817.1969572 112.8814040 776216 817.1969572 112.8814040 +0.0880 5849.8932223 804.0927442 111.1685235 932223 804.0927442 111.1685235 +0.0890 5752.3325494 791.3115660 109.4968341 325494 791.3115660 109.4968341 +0.0900 5657.2078026 778.8423203 107.8649368 078026 778.8423203 107.8649368 +0.0910 5564.4351069 766.6743978 106.2714943 351069 766.6743978 106.2714943 +0.0920 5473.9342981 754.7976551 104.7152279 342981 754.7976551 104.7152279 +0.0930 5385.6287222 743.2023904 103.1949141 287222 743.2023904 103.1949141 +0.0940 5299.4450471 731.8793190 101.7093819 450471 731.8793190 101.7093819 +0.0950 5215.3130867 720.8195518 100.2575097 130867 720.8195518 100.2575097 +0.0960 5133.1656359 710.0145745 98.8382229 1656359 710.0145745 98.8382229 +0.0970 5052.9383157 699.4562281 97.4504918 9383157 699.4562281 97.4504918 +0.0980 4974.5694279 689.1366911 96.0933285 5694279 689.1366911 96.0933285 +0.0990 4897.9998188 679.0484617 94.7657857 9998188 679.0484617 94.7657857 +0.1000 4823.1727507 669.1843423 93.4669540 1727507 669.1843423 93.4669540 +0.1010 4750.0337815 659.5374244 92.1959604 0337815 659.5374244 92.1959604 +0.1020 4678.5306510 650.1010737 90.9519663 5306510 650.1010737 90.9519663 +0.1030 4608.6131741 640.8689177 89.7341659 6131741 640.8689177 89.7341659 +0.1040 4540.2331402 631.8348323 88.5417846 2331402 631.8348323 88.5417846 +0.1050 4473.3442182 622.9929301 87.3740776 3442182 622.9929301 87.3740776 +0.1060 4407.9018671 614.3375495 86.2303284 9018671 614.3375495 86.2303284 +0.1070 4343.8632512 605.8632435 85.1098475 8632512 605.8632435 85.1098475 +0.1080 4281.1871608 597.5647703 84.0119712 1871608 597.5647703 84.0119712 +0.1090 4219.8339365 589.4370834 82.9360602 8339365 589.4370834 82.9360602 +0.1100 4159.7653981 581.4753230 81.8814988 7653981 581.4753230 81.8814988 +0.1110 4100.9447770 573.6748074 80.8476936 9447770 573.6748074 80.8476936 +0.1120 4043.3366528 566.0310249 79.8340726 3366528 566.0310249 79.8340726 +0.1130 3986.9068928 558.5396262 78.8400844 9068928 558.5396262 78.8400844 +0.1140 3931.6225949 551.1964175 77.8651968 6225949 551.1964175 77.8651968 +0.1150 3877.4520333 543.9973537 76.9088966 4520333 543.9973537 76.9088966 +0.1160 3824.3646071 536.9385314 75.9706883 3646071 536.9385314 75.9706883 +0.1170 3772.3307921 530.0161836 75.0500935 3307921 530.0161836 75.0500935 +0.1180 3721.3220939 523.2266731 74.1466504 3220939 523.2266731 74.1466504 +0.1190 3671.3110047 516.5664876 73.2599126 3110047 516.5664876 73.2599126 +0.1200 3622.2709610 510.0322343 72.3894489 2709610 510.0322343 72.3894489 +0.1210 3574.1763045 503.6206346 71.5348425 1763045 503.6206346 71.5348425 +0.1220 3527.0022444 497.3285198 70.6956904 0022444 497.3285198 70.6956904 +0.1230 3480.7248214 491.1528264 69.8716028 7248214 491.1528264 69.8716028 +0.1240 3435.3208739 485.0905919 69.0622028 3208739 485.0905919 69.0622028 +0.1250 3390.7680053 479.1389505 68.2671256 7680053 479.1389505 68.2671256 +0.1260 3347.0445536 473.2951298 67.4860180 0445536 473.2951298 67.4860180 +0.1270 3304.1295618 467.5564462 66.7185382 1295618 467.5564462 66.7185382 +0.1280 3262.0027498 461.9203022 65.9643553 0027498 461.9203022 65.9643553 +0.1290 3220.6444879 456.3841826 65.2231486 6444879 456.3841826 65.2231486 +0.1300 3180.0357713 450.9456517 64.4946075 0357713 450.9456517 64.4946075 +0.1310 3140.1581958 445.6023495 63.7784310 1581958 445.6023495 63.7784310 +0.1320 3100.9939349 440.3519898 63.0743273 9939349 440.3519898 63.0743273 +0.1330 3062.5257173 435.1923563 62.3820136 5257173 435.1923563 62.3820136 +0.1340 3024.7368060 430.1213011 61.7012156 7368060 430.1213011 61.7012156 +0.1350 2987.6109783 425.1367411 61.0316673 6109783 425.1367411 61.0316673 +0.1360 2951.1325064 420.2366563 60.3731105 1325064 420.2366563 60.3731105 +0.1370 2915.2861387 415.4190873 59.7252948 2861387 415.4190873 59.7252948 +0.1380 2880.0570829 410.6821328 59.0879771 0570829 410.6821328 59.0879771 +0.1390 2845.4309885 406.0239478 58.4609214 4309885 406.0239478 58.4609214 +0.1400 2811.3939311 401.4427415 57.8438986 3939311 401.4427415 57.8438986 +0.1410 2777.9323966 396.9367752 57.2366861 9323966 396.9367752 57.2366861 +0.1420 2745.0332671 392.5043608 56.6390678 0332671 392.5043608 56.6390678 +0.1430 2712.6838060 388.1438584 56.0508336 6838060 388.1438584 56.0508336 +0.1440 2680.8716450 383.8536753 55.4717796 8716450 383.8536753 55.4717796 +0.1450 2649.5847710 379.6322639 54.9017073 5847710 379.6322639 54.9017073 +0.1460 2618.8115134 375.4781203 54.3404239 8115134 375.4781203 54.3404239 +0.1470 2588.5405327 371.3897825 53.7877420 5405327 371.3897825 53.7877420 +0.1480 2558.7608086 367.3658297 53.2434792 7608086 367.3658297 53.2434792 +0.1490 2529.4616294 363.4048800 52.7074581 4616294 363.4048800 52.7074581 +0.1500 2500.6325811 359.5055896 52.1795062 6325811 359.5055896 52.1795062 +0.1510 2472.2635377 355.6666516 51.6594557 2635377 355.6666516 51.6594557 +0.1520 2444.3446512 351.8867943 51.1471432 3446512 351.8867943 51.1471432 +0.1530 2416.8663423 348.1647806 50.6424097 8663423 348.1647806 50.6424097 +0.1540 2389.8192919 344.4994064 50.1451003 8192919 344.4994064 50.1451003 +0.1550 2363.1944319 340.8894997 49.6550644 1944319 340.8894997 49.6550644 +0.1560 2336.9829372 337.3339196 49.1721551 9829372 337.3339196 49.1721551 +0.1570 2311.1762181 333.8315552 48.6962295 1762181 333.8315552 48.6962295 +0.1580 2285.7659120 330.3813249 48.2271484 7659120 330.3813249 48.2271484 +0.1590 2260.7438767 326.9821749 47.7647759 7438767 326.9821749 47.7647759 +0.1600 2236.1021827 323.6330788 47.3089799 1021827 323.6330788 47.3089799 +0.1610 2211.8331072 320.3330368 46.8596315 8331072 320.3330368 46.8596315 +0.1620 2187.9291268 317.0810744 46.4166050 9291268 317.0810744 46.4166050 +0.1630 2164.3829117 313.8762419 45.9797781 3829117 313.8762419 45.9797781 +0.1640 2141.1873194 310.7176137 45.5490312 1873194 310.7176137 45.5490312 +0.1650 2118.3353890 307.6042874 45.1242479 3353890 307.6042874 45.1242479 +0.1660 2095.8203354 304.5353832 44.7053147 8203354 304.5353832 44.7053147 +0.1670 2073.6355442 301.5100431 44.2921206 6355442 301.5100431 44.2921206 +0.1680 2051.7745660 298.5274302 43.8845577 7745660 298.5274302 43.8845577 +0.1690 2030.2311119 295.5867284 43.4825203 2311119 295.5867284 43.4825203 +0.1700 2008.9990482 292.6871414 43.0859057 9990482 292.6871414 43.0859057 +0.1710 1988.0723922 289.8278921 42.6946132 0723922 289.8278921 42.6946132 +0.1720 1967.4453070 287.0082225 42.3085448 4453070 287.0082225 42.3085448 +0.1730 1947.1120979 284.2273926 41.9276048 1120979 284.2273926 41.9276048 +0.1740 1927.0672078 281.4846800 41.5516997 0672078 281.4846800 41.5516997 +0.1750 1907.3052129 278.7793797 41.1807380 3052129 278.7793797 41.1807380 +0.1760 1887.8208195 276.1108033 40.8146308 8208195 276.1108033 40.8146308 +0.1770 1868.6088596 273.4782784 40.4532907 6088596 273.4782784 40.4532907 +0.1780 1849.6642873 270.8811486 40.0966328 6642873 270.8811486 40.0966328 +0.1790 1830.9821758 268.3187725 39.7445739 9821758 268.3187725 39.7445739 +0.1800 1812.5577133 265.7905239 39.3970327 5577133 265.7905239 39.3970327 +0.1810 1794.3862002 263.2957906 39.0539298 3862002 263.2957906 39.0539298 +0.1820 1776.4630458 260.8339748 38.7151877 4630458 260.8339748 38.7151877 +0.1830 1758.7837651 258.4044920 38.3807303 7837651 258.4044920 38.3807303 +0.1840 1741.3439759 256.0067712 38.0504836 3439759 256.0067712 38.0504836 +0.1850 1724.1393960 253.6402542 37.7243750 1393960 253.6402542 37.7243750 +0.1860 1707.1658404 251.3043952 37.4023336 1658404 251.3043952 37.4023336 +0.1870 1690.4192185 248.9986608 37.0842900 4192185 248.9986608 37.0842900 +0.1880 1673.8955316 246.7225293 36.7701764 8955316 246.7225293 36.7701764 +0.1890 1657.5908704 244.4754905 36.4599264 5908704 244.4754905 36.4599264 +0.1900 1641.5014126 242.2570456 36.1534752 5014126 242.2570456 36.1534752 +0.1910 1625.6234204 240.0667066 35.8507590 6234204 240.0667066 35.8507590 +0.1920 1609.9532382 237.9039960 35.5517159 9532382 237.9039960 35.5517159 +0.1930 1594.4872906 235.7684470 35.2562850 4872906 235.7684470 35.2562850 +0.1940 1579.2220803 233.6596024 34.9644066 2220803 233.6596024 34.9644066 +0.1950 1564.1541856 231.5770153 34.6760226 1541856 231.5770153 34.6760226 +0.1960 1549.2802589 229.5202480 34.3910759 2802589 229.5202480 34.3910759 +0.1970 1534.5970244 227.4888722 34.1095106 5970244 227.4888722 34.1095106 +0.1980 1520.1012763 225.4824686 33.8312720 1012763 225.4824686 33.8312720 +0.1990 1505.7898772 223.5006269 33.5563066 7898772 223.5006269 33.5563066 +0.2000 1491.6597561 221.5429453 33.2845619 6597561 221.5429453 33.2845619 +0.2010 1477.7079067 219.6090303 33.0159865 7079067 219.6090303 33.0159865 +0.2020 1463.9313857 217.6984967 32.7505301 9313857 217.6984967 32.7505301 +0.2030 1450.3273114 215.8109671 32.4881435 3273114 215.8109671 32.4881435 +0.2040 1436.8928620 213.9460720 32.2287782 8928620 213.9460720 32.2287782 +0.2050 1423.6252739 212.1034495 31.9723870 6252739 212.1034495 31.9723870 +0.2060 1410.5218407 210.2827449 31.7189234 5218407 210.2827449 31.7189234 +0.2070 1397.5799113 208.4836109 31.4683421 5799113 208.4836109 31.4683421 +0.2080 1384.7968886 206.7057069 31.2205985 7968886 206.7057069 31.2205985 +0.2090 1372.1702286 204.9486995 30.9756490 1702286 204.9486995 30.9756490 +0.2100 1359.6974383 203.2122618 30.7334506 6974383 203.2122618 30.7334506 +0.2110 1347.3760753 201.4960735 30.4939616 3760753 201.4960735 30.4939616 +0.2120 1335.2037458 199.7998204 30.2571406 2037458 199.7998204 30.2571406 +0.2130 1323.1781040 198.1231947 30.0229474 1781040 198.1231947 30.0229474 +0.2140 1311.2968506 196.4658948 29.7913425 2968506 196.4658948 29.7913425 +0.2150 1299.5577318 194.8276247 29.5622869 5577318 194.8276247 29.5622869 +0.2160 1287.9585380 193.2080943 29.3357428 9585380 193.2080943 29.3357428 +0.2170 1276.4971033 191.6070192 29.1116726 4971033 191.6070192 29.1116726 +0.2180 1265.1713036 190.0241204 28.8900399 1713036 190.0241204 28.8900399 +0.2190 1253.9790564 188.4591242 28.6708087 9790564 188.4591242 28.6708087 +0.2200 1242.9183194 186.9117623 28.4539438 9183194 186.9117623 28.4539438 +0.2210 1231.9870896 185.3817715 28.2394106 9870896 185.3817715 28.2394106 +0.2220 1221.1834024 183.8688934 28.0271751 1834024 183.8688934 28.0271751 +0.2230 1210.5053309 182.3728746 27.8172040 5053309 182.3728746 27.8172040 +0.2240 1199.9509845 180.8934666 27.6094647 9509845 180.8934666 27.6094647 +0.2250 1189.5185088 179.4304255 27.4039251 5185088 179.4304255 27.4039251 +0.2260 1179.2060841 177.9835117 27.2005537 2060841 177.9835117 27.2005537 +0.2270 1169.0119251 176.5524904 26.9993196 0119251 176.5524904 26.9993196 +0.2280 1158.9342796 175.1371309 26.8001924 9342796 175.1371309 26.8001924 +0.2290 1148.9714282 173.7372069 26.6031423 9714282 173.7372069 26.6031423 +0.2300 1139.1216834 172.3524963 26.4081402 1216834 172.3524963 26.4081402 +0.2310 1129.3833889 170.9827808 26.2151571 3833889 170.9827808 26.2151571 +0.2320 1119.7549187 169.6278463 26.0241650 7549187 169.6278463 26.0241650 +0.2330 1110.2346768 168.2874825 25.8351362 2346768 168.2874825 25.8351362 +0.2340 1100.8210961 166.9614829 25.6480433 8210961 166.9614829 25.6480433 +0.2350 1091.5126382 165.6496448 25.4628597 5126382 165.6496448 25.4628597 +0.2360 1082.3077924 164.3517691 25.2795591 3077924 164.3517691 25.2795591 +0.2370 1073.2050753 163.0676601 25.0981157 2050753 163.0676601 25.0981157 +0.2380 1064.2030300 161.7971257 24.9185042 2030300 161.7971257 24.9185042 +0.2390 1055.3002259 160.5399772 24.7406996 3002259 160.5399772 24.7406996 +0.2400 1046.4952577 159.2960293 24.5646775 4952577 159.2960293 24.5646775 +0.2410 1037.7867450 158.0650998 24.3904138 7867450 158.0650998 24.3904138 +0.2420 1029.1733319 156.8470097 24.2178849 1733319 156.8470097 24.2178849 +0.2430 1020.6536864 155.6415833 24.0470676 6536864 155.6415833 24.0470676 +0.2440 1012.2264998 154.4486478 23.8779390 2264998 154.4486478 23.8779390 +0.2450 1003.8904862 153.2680334 23.7104767 8904862 153.2680334 23.7104767 +0.2460 995.6443822 152.0995732 23.5446586 6443822 152.0995732 23.5446586 +0.2470 987.4869462 150.9431032 23.3804631 4869462 150.9431032 23.3804631 +0.2480 979.4169580 149.7984622 23.2178688 4169580 149.7984622 23.2178688 +0.2490 971.4332186 148.6654918 23.0568547 4332186 148.6654918 23.0568547 +0.2500 963.5345494 147.5440362 22.8974003 5345494 147.5440362 22.8974003 +0.2510 955.7197919 146.4339423 22.7394853 7197919 146.4339423 22.7394853 +0.2520 947.9878073 145.3350596 22.5830897 9878073 145.3350596 22.5830897 +0.2530 940.3374762 144.2472399 22.4281939 3374762 144.2472399 22.4281939 +0.2540 932.7676981 143.1703377 22.2747787 7676981 143.1703377 22.2747787 +0.2550 925.2773908 142.1042099 22.1228251 2773908 142.1042099 22.1228251 +0.2560 917.8654906 141.0487157 21.9723144 8654906 141.0487157 21.9723144 +0.2570 910.5309511 140.0037168 21.8232283 5309511 140.0037168 21.8232283 +0.2580 903.2727438 138.9690768 21.6755488 2727438 138.9690768 21.6755488 +0.2590 896.0898568 137.9446619 21.5292580 0898568 137.9446619 21.5292580 +0.2600 888.9812952 136.9303404 21.3843385 9812952 136.9303404 21.3843385 +0.2610 881.9460805 135.9259827 21.2407731 9460805 135.9259827 21.2407731 +0.2620 874.9832499 134.9314613 21.0985449 9832499 134.9314613 21.0985449 +0.2630 868.0918568 133.9466506 20.9576372 0918568 133.9466506 20.9576372 +0.2640 861.2709696 132.9714274 20.8180337 2709696 132.9714274 20.8180337 +0.2650 854.5196721 132.0056702 20.6797182 5196721 132.0056702 20.6797182 +0.2660 847.8370627 131.0492595 20.5426748 8370627 131.0492595 20.5426748 +0.2670 841.2222545 130.1020778 20.4068880 2222545 130.1020778 20.4068880 +0.2680 834.6743747 129.1640092 20.2723424 6743747 129.1640092 20.2723424 +0.2690 828.1925646 128.2349399 20.1390228 1925646 128.2349399 20.1390228 +0.2700 821.7759790 127.3147578 20.0069143 7759790 127.3147578 20.0069143 +0.2710 815.4237863 126.4033526 19.8760022 4237863 126.4033526 19.8760022 +0.2720 809.1351680 125.5006157 19.7462722 1351680 125.5006157 19.7462722 +0.2730 802.9093184 124.6064402 19.6177099 9093184 124.6064402 19.6177099 +0.2740 796.7454448 123.7207207 19.4903015 7454448 123.7207207 19.4903015 +0.2750 790.6427665 122.8433537 19.3640330 6427665 122.8433537 19.3640330 +0.2760 784.6005152 121.9742372 19.2388909 6005152 121.9742372 19.2388909 +0.2770 778.6179347 121.1132707 19.1148619 6179347 121.1132707 19.1148619 +0.2780 772.6942802 120.2603553 18.9919327 6942802 120.2603553 18.9919327 +0.2790 766.8288187 119.4153936 18.8700905 8288187 119.4153936 18.8700905 +0.2800 761.0208283 118.5782897 18.7493224 0208283 118.5782897 18.7493224 +0.2810 755.2695984 117.7489490 18.6296158 2695984 117.7489490 18.6296158 +0.2820 749.5744291 116.9272787 18.5109583 5744291 116.9272787 18.5109583 +0.2830 743.9346311 116.1131870 18.3933378 9346311 116.1131870 18.3933378 +0.2840 738.3495259 115.3065837 18.2767421 3495259 115.3065837 18.2767421 +0.2850 732.8184450 114.5073800 18.1611595 8184450 114.5073800 18.1611595 +0.2860 727.3407300 113.7154881 18.0465783 3407300 113.7154881 18.0465783 +0.2870 721.9157327 112.9308219 17.9329869 9157327 112.9308219 17.9329869 +0.2880 716.5428142 112.1532963 17.8203740 5428142 112.1532963 17.8203740 +0.2890 711.2213456 111.3828276 17.7087284 2213456 111.3828276 17.7087284 +0.2900 705.9507071 110.6193334 17.5980392 9507071 110.6193334 17.5980392 +0.2910 700.7302882 109.8627322 17.4882954 7302882 109.8627322 17.4882954 +0.2920 695.5594874 109.1129441 17.3794864 5594874 109.1129441 17.3794864 +0.2930 690.4377121 108.3698900 17.2716016 4377121 108.3698900 17.2716016 +0.2940 685.3643785 107.6334922 17.1646306 3643785 107.6334922 17.1646306 +0.2950 680.3389114 106.9036741 17.0585633 3389114 106.9036741 17.0585633 +0.2960 675.3607438 106.1803600 16.9533894 3607438 106.1803600 16.9533894 +0.2970 670.4293171 105.4634755 16.8490991 4293171 105.4634755 16.8490991 +0.2980 665.5440809 104.7529472 16.7456826 5440809 104.7529472 16.7456826 +0.2990 660.7044927 104.0487027 16.6431300 7044927 104.0487027 16.6431300 +0.3000 655.9100179 103.3506708 16.5414320 9100179 103.3506708 16.5414320 +0.3010 651.1601295 102.6587812 16.4405792 1601295 102.6587812 16.4405792 +0.3020 646.4543081 101.9729644 16.3405621 4543081 101.9729644 16.3405621 +0.3030 641.7920419 101.2931523 16.2413718 7920419 101.2931523 16.2413718 +0.3040 637.1728262 100.6192774 16.1429992 1728262 100.6192774 16.1429992 +0.3050 632.5961636 99.9512734 16.0454353 .5961636 99.9512734 16.0454353 +0.3060 628.0615636 99.2890746 15.9486715 .0615636 99.2890746 15.9486715 +0.3070 623.5685430 98.6326167 15.8526991 .5685430 98.6326167 15.8526991 +0.3080 619.1166250 97.9818358 15.7575095 .1166250 97.9818358 15.7575095 +0.3090 614.7053397 97.3366692 15.6630943 .7053397 97.3366692 15.6630943 +0.3100 610.3342239 96.6970550 15.5694453 .3342239 96.6970550 15.5694453 +0.3110 606.0028208 96.0629321 15.4765543 .0028208 96.0629321 15.4765543 +0.3120 601.7106798 95.4342402 15.3844132 .7106798 95.4342402 15.3844132 +0.3130 597.4573568 94.8109200 15.2930139 .4573568 94.8109200 15.2930139 +0.3140 593.2424137 94.1929129 15.2023488 .2424137 94.1929129 15.2023488 +0.3150 589.0654185 93.5801610 15.1124099 .0654185 93.5801610 15.1124099 +0.3160 584.9259453 92.9726074 15.0231897 .9259453 92.9726074 15.0231897 +0.3170 580.8235739 92.3701958 14.9346807 .8235739 92.3701958 14.9346807 +0.3180 576.7578898 91.7728708 14.8468753 .7578898 91.7728708 14.8468753 +0.3190 572.7284843 91.1805775 14.7597662 .7284843 91.1805775 14.7597662 +0.3200 568.7349543 90.5932620 14.6733463 .7349543 90.5932620 14.6733463 +0.3210 564.7769020 90.0108711 14.5876082 .7769020 90.0108711 14.5876082 +0.3220 560.8539351 89.4333521 14.5025451 .8539351 89.4333521 14.5025451 +0.3230 556.9656667 88.8606532 14.4181498 .9656667 88.8606532 14.4181498 +0.3240 553.1117150 88.2927232 14.3344156 .1117150 88.2927232 14.3344156 +0.3250 549.2917033 87.7295116 14.2513356 .2917033 87.7295116 14.2513356 +0.3260 545.5052600 87.1709686 14.1689032 .5052600 87.1709686 14.1689032 +0.3270 541.7520185 86.6170450 14.0871117 .7520185 86.6170450 14.0871117 +0.3280 538.0316170 86.0676922 14.0059546 .0316170 86.0676922 14.0059546 +0.3290 534.3436986 85.5228624 13.9254255 .3436986 85.5228624 13.9254255 +0.3300 530.6879111 84.9825082 13.8455180 .6879111 84.9825082 13.8455180 +0.3310 527.0639069 84.4465831 13.7662258 .0639069 84.4465831 13.7662258 +0.3320 523.4713431 83.9150409 13.6875428 .4713431 83.9150409 13.6875428 +0.3330 519.9098812 83.3878361 13.6094628 .9098812 83.3878361 13.6094628 +0.3340 516.3791872 82.8649240 13.5319798 .3791872 82.8649240 13.5319798 +0.3350 512.8789315 82.3462602 13.4550878 .8789315 82.3462602 13.4550878 +0.3360 509.4087887 81.8318010 13.3787809 .4087887 81.8318010 13.3787809 +0.3370 505.9684377 81.3215032 13.3030534 .9684377 81.3215032 13.3030534 +0.3380 502.5575616 80.8153242 13.2278994 .5575616 80.8153242 13.2278994 +0.3390 499.1758475 80.3132218 13.1533134 .1758475 80.3132218 13.1533134 +0.3400 495.8229866 79.8151546 13.0792897 .8229866 79.8151546 13.0792897 +0.3410 492.4986741 79.3210816 13.0058227 .4986741 79.3210816 13.0058227 +0.3420 489.2026091 78.8309622 12.9329071 .2026091 78.8309622 12.9329071 +0.3430 485.9344945 78.3447564 12.8605374 .9344945 78.3447564 12.8605374 +0.3440 482.6940372 77.8624247 12.7887084 .6940372 77.8624247 12.7887084 +0.3450 479.4809474 77.3839282 12.7174147 .4809474 77.3839282 12.7174147 +0.3460 476.2949395 76.9092283 12.6466511 .2949395 76.9092283 12.6466511 +0.3470 473.1357312 76.4382870 12.5764126 .1357312 76.4382870 12.5764126 +0.3480 470.0030438 75.9710667 12.5066941 .0030438 75.9710667 12.5066941 +0.3490 466.8966023 75.5075304 12.4374905 .8966023 75.5075304 12.4374905 +0.3500 463.8161349 75.0476413 12.3687969 .8161349 75.0476413 12.3687969 +0.3510 460.7613734 74.5913633 12.3006084 .7613734 74.5913633 12.3006084 +0.3520 457.7320529 74.1386607 12.2329203 .7320529 74.1386607 12.2329203 +0.3530 454.7279118 73.6894981 12.1657276 .7279118 73.6894981 12.1657276 +0.3540 451.7486918 73.2438407 12.0990258 .7486918 73.2438407 12.0990258 +0.3550 448.7941378 72.8016540 12.0328101 .7941378 72.8016540 12.0328101 +0.3560 445.8639978 72.3629040 11.9670761 .8639978 72.3629040 11.9670761 +0.3570 442.9580230 71.9275570 11.9018190 .9580230 71.9275570 11.9018190 +0.3580 440.0759676 71.4955799 11.8370344 .0759676 71.4955799 11.8370344 +0.3590 437.2175888 71.0669398 11.7727180 .2175888 71.0669398 11.7727180 +0.3600 434.3826470 70.6416043 11.7088653 .3826470 70.6416043 11.7088653 +0.3610 431.5709052 70.2195415 11.6454719 .5709052 70.2195415 11.6454719 +0.3620 428.7821296 69.8007195 11.5825337 .7821296 69.8007195 11.5825337 +0.3630 426.0160891 69.3851072 11.5200463 .0160891 69.3851072 11.5200463 +0.3640 423.2725553 68.9726737 11.4580056 .2725553 68.9726737 11.4580056 +0.3650 420.5513029 68.5633884 11.3964074 .5513029 68.5633884 11.3964074 +0.3660 417.8521090 68.1572211 11.3352478 .8521090 68.1572211 11.3352478 +0.3670 415.1747536 67.7541421 11.2745225 .1747536 67.7541421 11.2745225 +0.3680 412.5190192 67.3541219 11.2142277 .5190192 67.3541219 11.2142277 +0.3690 409.8846910 66.9571314 11.1543594 .8846910 66.9571314 11.1543594 +0.3700 407.2715568 66.5631417 11.0949137 .2715568 66.5631417 11.0949137 +0.3710 404.6794069 66.1721246 11.0358867 .6794069 66.1721246 11.0358867 +0.3720 402.1080341 65.7840517 10.9772747 .1080341 65.7840517 10.9772747 +0.3730 399.5572337 65.3988954 10.9190739 .5572337 65.3988954 10.9190739 +0.3740 397.0268033 65.0166283 10.8612804 .0268033 65.0166283 10.8612804 +0.3750 394.5165432 64.6372231 10.8038908 .5165432 64.6372231 10.8038908 +0.3760 392.0262556 64.2606530 10.7469013 .0262556 64.2606530 10.7469013 +0.3770 389.5557456 63.8868916 10.6903083 .5557456 63.8868916 10.6903083 +0.3780 387.1048200 63.5159126 10.6341083 .1048200 63.5159126 10.6341083 +0.3790 384.6732884 63.1476901 10.5782978 .6732884 63.1476901 10.5782978 +0.3800 382.2609623 62.7821985 10.5228732 .2609623 62.7821985 10.5228732 +0.3810 379.8676555 62.4194124 10.4678312 .8676555 62.4194124 10.4678312 +0.3820 377.4931840 62.0593069 10.4131683 .4931840 62.0593069 10.4131683 +0.3830 375.1373659 61.7018572 10.3588812 .1373659 61.7018572 10.3588812 +0.3840 372.8000214 61.3470387 10.3049666 .8000214 61.3470387 10.3049666 +0.3850 370.4809729 60.9948274 10.2514211 .4809729 60.9948274 10.2514211 +0.3860 368.1800447 60.6451993 10.1982415 .1800447 60.6451993 10.1982415 +0.3870 365.8970633 60.2981307 10.1454247 .8970633 60.2981307 10.1454247 +0.3880 363.6318570 59.9535983 10.0929674 .6318570 59.9535983 10.0929674 +0.3890 361.3842561 59.6115790 10.0408664 .3842561 59.6115790 10.0408664 +0.3900 359.1540931 59.2720498 9.9891187 9.1540931 59.2720498 9.9891187 +0.3910 356.9412021 58.9349881 9.9377213 6.9412021 58.9349881 9.9377213 +0.3920 354.7454193 58.6003717 9.8866709 4.7454193 58.6003717 9.8866709 +0.3930 352.5665826 58.2681784 9.8359647 2.5665826 58.2681784 9.8359647 +0.3940 350.4045319 57.9383862 9.7855997 0.4045319 57.9383862 9.7855997 +0.3950 348.2591089 57.6109737 9.7355729 8.2591089 57.6109737 9.7355729 +0.3960 346.1301569 57.2859194 9.6858814 6.1301569 57.2859194 9.6858814 +0.3970 344.0175213 56.9632022 9.6365223 4.0175213 56.9632022 9.6365223 +0.3980 341.9210489 56.6428011 9.5874928 1.9210489 56.6428011 9.5874928 +0.3990 339.8405885 56.3246954 9.5387901 9.8405885 56.3246954 9.5387901 +0.4000 337.7759903 56.0088648 9.4904113 7.7759903 56.0088648 9.4904113 +0.4010 335.7271066 55.6952889 9.4423537 5.7271066 55.6952889 9.4423537 +0.4020 333.6937909 55.3839477 9.3946146 3.6937909 55.3839477 9.3946146 +0.4030 331.6758987 55.0748214 9.3471913 1.6758987 55.0748214 9.3471913 +0.4040 329.6732868 54.7678904 9.3000811 9.6732868 54.7678904 9.3000811 +0.4050 327.6858138 54.4631355 9.2532814 7.6858138 54.4631355 9.2532814 +0.4060 325.7133399 54.1605373 9.2067894 5.7133399 54.1605373 9.2067894 +0.4070 323.7557266 53.8600769 9.1606028 3.7557266 53.8600769 9.1606028 +0.4080 321.8128372 53.5617355 9.1147188 1.8128372 53.5617355 9.1147188 +0.4090 319.8845364 53.2654947 9.0691350 9.8845364 53.2654947 9.0691350 +0.4100 317.9706903 52.9713360 9.0238489 7.9706903 52.9713360 9.0238489 +0.4110 316.0711666 52.6792413 8.9788579 6.0711666 52.6792413 8.9788579 +0.4120 314.1858343 52.3891926 8.9341596 4.1858343 52.3891926 8.9341596 +0.4130 312.3145641 52.1011721 8.8897516 2.3145641 52.1011721 8.8897516 +0.4140 310.4572279 51.8151622 8.8456315 0.4572279 51.8151622 8.8456315 +0.4150 308.6136989 51.5311456 8.8017969 8.6136989 51.5311456 8.8017969 +0.4160 306.7838520 51.2491049 8.7582454 6.7838520 51.2491049 8.7582454 +0.4170 304.9675631 50.9690232 8.7149747 4.9675631 50.9690232 8.7149747 +0.4180 303.1647097 50.6908836 8.6719826 3.1647097 50.6908836 8.6719826 +0.4190 301.3751706 50.4146694 8.6292666 1.3751706 50.4146694 8.6292666 +0.4200 299.5988257 50.1403641 8.5868246 9.5988257 50.1403641 8.5868246 +0.4210 297.8355565 49.8679514 8.5446543 7.8355565 49.8679514 8.5446543 +0.4220 296.0852455 49.5974150 8.5027536 6.0852455 49.5974150 8.5027536 +0.4230 294.3477765 49.3287390 8.4611201 4.3477765 49.3287390 8.4611201 +0.4240 292.6230348 49.0619075 8.4197518 2.6230348 49.0619075 8.4197518 +0.4250 290.9109067 48.7969049 8.3786464 0.9109067 48.7969049 8.3786464 +0.4260 289.2112796 48.5337157 8.3378020 9.2112796 48.5337157 8.3378020 +0.4270 287.5240424 48.2723245 8.2972163 7.5240424 48.2723245 8.2972163 +0.4280 285.8490849 48.0127161 8.2568873 5.8490849 48.0127161 8.2568873 +0.4290 284.1862984 47.7548755 8.2168129 4.1862984 47.7548755 8.2168129 +0.4300 282.5355749 47.4987878 8.1769910 2.5355749 47.4987878 8.1769910 +0.4310 280.8968079 47.2444382 8.1374197 0.8968079 47.2444382 8.1374197 +0.4320 279.2698919 46.9918123 8.0980970 9.2698919 46.9918123 8.0980970 +0.4330 277.6547226 46.7408954 8.0590208 7.6547226 46.7408954 8.0590208 +0.4340 276.0511966 46.4916735 8.0201893 6.0511966 46.4916735 8.0201893 +0.4350 274.4592117 46.2441323 7.9816004 4.4592117 46.2441323 7.9816004 +0.4360 272.8786668 45.9982578 7.9432522 2.8786668 45.9982578 7.9432522 +0.4370 271.3094618 45.7540362 7.9051429 1.3094618 45.7540362 7.9051429 +0.4380 269.7514977 45.5114537 7.8672705 9.7514977 45.5114537 7.8672705 +0.4390 268.2046764 45.2704969 7.8296332 8.2046764 45.2704969 7.8296332 +0.4400 266.6689010 45.0311521 7.7922291 6.6689010 45.0311521 7.7922291 +0.4410 265.1440755 44.7934062 7.7550565 5.1440755 44.7934062 7.7550565 +0.4420 263.6301049 44.5572460 7.7181134 3.6301049 44.5572460 7.7181134 +0.4430 262.1268953 44.3226583 7.6813981 2.1268953 44.3226583 7.6813981 +0.4440 260.6343534 44.0896304 7.6449088 0.6343534 44.0896304 7.6449088 +0.4450 259.1523873 43.8581493 7.6086438 9.1523873 43.8581493 7.6086438 +0.4460 257.6809058 43.6282025 7.5726013 7.6809058 43.6282025 7.5726013 +0.4470 256.2198188 43.3997775 7.5367796 6.2198188 43.3997775 7.5367796 +0.4480 254.7690369 43.1728617 7.5011769 4.7690369 43.1728617 7.5011769 +0.4490 253.3284718 42.9474430 7.4657916 3.3284718 42.9474430 7.4657916 +0.4500 251.8980359 42.7235091 7.4306220 1.8980359 42.7235091 7.4306220 +0.4510 250.4776428 42.5010480 7.3956665 0.4776428 42.5010480 7.3956665 +0.4520 249.0672067 42.2800478 7.3609233 9.0672067 42.2800478 7.3609233 +0.4530 247.6666428 42.0604967 7.3263909 7.6666428 42.0604967 7.3263909 +0.4540 246.2758672 41.8423830 7.2920677 6.2758672 41.8423830 7.2920677 +0.4550 244.8947966 41.6256950 7.2579520 4.8947966 41.6256950 7.2579520 +0.4560 243.5233489 41.4104214 7.2240422 3.5233489 41.4104214 7.2240422 +0.4570 242.1614425 41.1965508 7.1903369 2.1614425 41.1965508 7.1903369 +0.4580 240.8089969 40.9840718 7.1568344 0.8089969 40.9840718 7.1568344 +0.4590 239.4659322 40.7729735 7.1235332 9.4659322 40.7729735 7.1235332 +0.4600 238.1321693 40.5632447 7.0904317 8.1321693 40.5632447 7.0904317 +0.4610 236.8076301 40.3548745 7.0575286 6.8076301 40.3548745 7.0575286 +0.4620 235.4922372 40.1478521 7.0248222 5.4922372 40.1478521 7.0248222 +0.4630 234.1859137 39.9421668 6.9923110 4.1859137 39.9421668 6.9923110 +0.4640 232.8885838 39.7378079 6.9599937 2.8885838 39.7378079 6.9599937 +0.4650 231.6001723 39.5347651 6.9278688 1.6001723 39.5347651 6.9278688 +0.4660 230.3206049 39.3330277 6.8959348 0.3206049 39.3330277 6.8959348 +0.4670 229.0498078 39.1325857 6.8641902 9.0498078 39.1325857 6.8641902 +0.4680 227.7877080 38.9334286 6.8326338 7.7877080 38.9334286 6.8326338 +0.4690 226.5342334 38.7355464 6.8012640 6.5342334 38.7355464 6.8012640 +0.4700 225.2893125 38.5389292 6.7700794 5.2893125 38.5389292 6.7700794 +0.4710 224.0528743 38.3435668 6.7390788 4.0528743 38.3435668 6.7390788 +0.4720 222.8248488 38.1494496 6.7082608 2.8248488 38.1494496 6.7082608 +0.4730 221.6051665 37.9565678 6.6776239 1.6051665 37.9565678 6.6776239 +0.4740 220.3937588 37.7649118 6.6471669 0.3937588 37.7649118 6.6471669 +0.4750 219.1905574 37.5744719 6.6168884 9.1905574 37.5744719 6.6168884 +0.4760 217.9954950 37.3852387 6.5867871 7.9954950 37.3852387 6.5867871 +0.4770 216.8085049 37.1972029 6.5568617 6.8085049 37.1972029 6.5568617 +0.4780 215.6295208 37.0103551 6.5271109 5.6295208 37.0103551 6.5271109 +0.4790 214.4584774 36.8246861 6.4975335 4.4584774 36.8246861 6.4975335 +0.4800 213.2953097 36.6401869 6.4681281 3.2953097 36.6401869 6.4681281 +0.4810 212.1399536 36.4568483 6.4388935 2.1399536 36.4568483 6.4388935 +0.4820 210.9923455 36.2746615 6.4098286 0.9923455 36.2746615 6.4098286 +0.4830 209.8524225 36.0936176 6.3809319 9.8524225 36.0936176 6.3809319 +0.4840 208.7201220 35.9137077 6.3522023 8.7201220 35.9137077 6.3522023 +0.4850 207.5953824 35.7349232 6.3236387 7.5953824 35.7349232 6.3236387 +0.4860 206.4781425 35.5572555 6.2952397 6.4781425 35.5572555 6.2952397 +0.4870 205.3683417 35.3806959 6.2670043 5.3683417 35.3806959 6.2670043 +0.4880 204.2659200 35.2052361 6.2389312 4.2659200 35.2052361 6.2389312 +0.4890 203.1708179 35.0308677 6.2110192 3.1708179 35.0308677 6.2110192 +0.4900 202.0829765 34.8575823 6.1832672 2.0829765 34.8575823 6.1832672 +0.4910 201.0023376 34.6853717 6.1556741 1.0023376 34.6853717 6.1556741 +0.4920 199.9288434 34.5142278 6.1282387 9.9288434 34.5142278 6.1282387 +0.4930 198.8624366 34.3441424 6.1009599 8.8624366 34.3441424 6.1009599 +0.4940 197.8030607 34.1751076 6.0738365 7.8030607 34.1751076 6.0738365 +0.4950 196.7506594 34.0071154 6.0468675 6.7506594 34.0071154 6.0468675 +0.4960 195.7051773 33.8401579 6.0200517 5.7051773 33.8401579 6.0200517 +0.4970 194.6665591 33.6742273 5.9933881 4.6665591 33.6742273 5.9933881 +0.4980 193.6347503 33.5093160 5.9668756 3.6347503 33.5093160 5.9668756 +0.4990 192.6096969 33.3454163 5.9405131 2.6096969 33.3454163 5.9405131 +0.5000 191.5913454 33.1825205 5.9142996 1.5913454 33.1825205 5.9142996 +0.5010 190.5796426 33.0206211 5.8882340 0.5796426 33.0206211 5.8882340 +0.5020 189.5745362 32.8597108 5.8623152 9.5745362 32.8597108 5.8623152 +0.5030 188.5759739 32.6997821 5.8365423 8.5759739 32.6997821 5.8365423 +0.5040 187.5839043 32.5408276 5.8109141 7.5839043 32.5408276 5.8109141 +0.5050 186.5982763 32.3828402 5.7854298 6.5982763 32.3828402 5.7854298 +0.5060 185.6190392 32.2258127 5.7600882 5.6190392 32.2258127 5.7600882 +0.5070 184.6461430 32.0697379 5.7348884 4.6461430 32.0697379 5.7348884 +0.5080 183.6795379 31.9146087 5.7098294 3.6795379 31.9146087 5.7098294 +0.5090 182.7191747 31.7604183 5.6849101 2.7191747 31.7604183 5.6849101 +0.5100 181.7650046 31.6071595 5.6601298 1.7650046 31.6071595 5.6601298 +0.5110 180.8169795 31.4548256 5.6354873 0.8169795 31.4548256 5.6354873 +0.5120 179.8750513 31.3034097 5.6109817 9.8750513 31.3034097 5.6109817 +0.5130 178.9391727 31.1529051 5.5866120 8.9391727 31.1529051 5.5866120 +0.5140 178.0092967 31.0033051 5.5623774 8.0092967 31.0033051 5.5623774 +0.5150 177.0853767 30.8546031 5.5382769 7.0853767 30.8546031 5.5382769 +0.5160 176.1673666 30.7067924 5.5143096 6.1673666 30.7067924 5.5143096 +0.5170 175.2552207 30.5598666 5.4904745 5.2552207 30.5598666 5.4904745 +0.5180 174.3488937 30.4138191 5.4667707 4.3488937 30.4138191 5.4667707 +0.5190 173.4483407 30.2686436 5.4431974 3.4483407 30.2686436 5.4431974 +0.5200 172.5535172 30.1243337 5.4197536 2.5535172 30.1243337 5.4197536 +0.5210 171.6643793 29.9808832 5.3964385 1.6643793 29.9808832 5.3964385 +0.5220 170.7808831 29.8382858 5.3732511 0.7808831 29.8382858 5.3732511 +0.5230 169.9029855 29.6965352 5.3501907 9.9029855 29.6965352 5.3501907 +0.5240 169.0306436 29.5556254 5.3272563 9.0306436 29.5556254 5.3272563 +0.5250 168.1638148 29.4155503 5.3044471 8.1638148 29.4155503 5.3044471 +0.5260 167.3024571 29.2763038 5.2817622 7.3024571 29.2763038 5.2817622 +0.5270 166.4465288 29.1378801 5.2592009 6.4465288 29.1378801 5.2592009 +0.5280 165.5959885 29.0002730 5.2367621 5.5959885 29.0002730 5.2367621 +0.5290 164.7507952 28.8634769 5.2144453 4.7507952 28.8634769 5.2144453 +0.5300 163.9109083 28.7274858 5.1922494 3.9109083 28.7274858 5.1922494 +0.5310 163.0762876 28.5922939 5.1701737 3.0762876 28.5922939 5.1701737 +0.5320 162.2468931 28.4578957 5.1482173 2.2468931 28.4578957 5.1482173 +0.5330 161.4226854 28.3242853 5.1263796 1.4226854 28.3242853 5.1263796 +0.5340 160.6036253 28.1914571 5.1046596 0.6036253 28.1914571 5.1046596 +0.5350 159.7896739 28.0594056 5.0830567 9.7896739 28.0594056 5.0830567 +0.5360 158.9807927 27.9281253 5.0615699 8.9807927 27.9281253 5.0615699 +0.5370 158.1769437 27.7976106 5.0401986 8.1769437 27.7976106 5.0401986 +0.5380 157.3780890 27.6678562 5.0189420 7.3780890 27.6678562 5.0189420 +0.5390 156.5841911 27.5388565 4.9977992 6.5841911 27.5388565 4.9977992 +0.5400 155.7952129 27.4106063 4.9767696 5.7952129 27.4106063 4.9767696 +0.5410 155.0111177 27.2831003 4.9558524 5.0111177 27.2831003 4.9558524 +0.5420 154.2318688 27.1563333 4.9350468 4.2318688 27.1563333 4.9350468 +0.5430 153.4574302 27.0302999 4.9143522 3.4574302 27.0302999 4.9143522 +0.5440 152.6877660 26.9049950 4.8937677 2.6877660 26.9049950 4.8937677 +0.5450 151.9228407 26.7804136 4.8732926 1.9228407 26.7804136 4.8732926 +0.5460 151.1626191 26.6565505 4.8529263 1.1626191 26.6565505 4.8529263 +0.5470 150.4070662 26.5334007 4.8326680 0.4070662 26.5334007 4.8326680 +0.5480 149.6561475 26.4109592 4.8125170 9.6561475 26.4109592 4.8125170 +0.5490 148.9098286 26.2892210 4.7924726 8.9098286 26.2892210 4.7924726 +0.5500 148.1680756 26.1681812 4.7725341 8.1680756 26.1681812 4.7725341 +0.5510 147.4308547 26.0478349 4.7527008 7.4308547 26.0478349 4.7527008 +0.5520 146.6981325 25.9281774 4.7329721 6.6981325 25.9281774 4.7329721 +0.5530 145.9698760 25.8092038 4.7133471 5.9698760 25.8092038 4.7133471 +0.5540 145.2460523 25.6909093 4.6938253 5.2460523 25.6909093 4.6938253 +0.5550 144.5266287 25.5732893 4.6744060 4.5266287 25.5732893 4.6744060 +0.5560 143.8115732 25.4563391 4.6550885 3.8115732 25.4563391 4.6550885 +0.5570 143.1008536 25.3400540 4.6358722 3.1008536 25.3400540 4.6358722 +0.5580 142.3944382 25.2244294 4.6167564 2.3944382 25.2244294 4.6167564 +0.5590 141.6922957 25.1094608 4.5977404 1.6922957 25.1094608 4.5977404 +0.5600 140.9943949 24.9951437 4.5788237 0.9943949 24.9951437 4.5788237 +0.5610 140.3007048 24.8814735 4.5600055 0.3007048 24.8814735 4.5600055 +0.5620 139.6111948 24.7684458 4.5412852 9.6111948 24.7684458 4.5412852 +0.5630 138.9258345 24.6560562 4.5226623 8.9258345 24.6560562 4.5226623 +0.5640 138.2445939 24.5443004 4.5041360 8.2445939 24.5443004 4.5041360 +0.5650 137.5674430 24.4331739 4.4857058 7.5674430 24.4331739 4.4857058 +0.5660 136.8943523 24.3226725 4.4673710 6.8943523 24.3226725 4.4673710 +0.5670 136.2252924 24.2127919 4.4491311 6.2252924 24.2127919 4.4491311 +0.5680 135.5602343 24.1035279 4.4309854 5.5602343 24.1035279 4.4309854 +0.5690 134.8991490 23.9948762 4.4129333 4.8991490 23.9948762 4.4129333 +0.5700 134.2420080 23.8868327 4.3949742 4.2420080 23.8868327 4.3949742 +0.5710 133.5887829 23.7793933 4.3771076 3.5887829 23.7793933 4.3771076 +0.5720 132.9394456 23.6725538 4.3593328 2.9394456 23.6725538 4.3593328 +0.5730 132.2939681 23.5663102 4.3416493 2.2939681 23.5663102 4.3416493 +0.5740 131.6523229 23.4606585 4.3240564 1.6523229 23.4606585 4.3240564 +0.5750 131.0144825 23.3555946 4.3065537 1.0144825 23.3555946 4.3065537 +0.5760 130.3804198 23.2511146 4.2891405 0.3804198 23.2511146 4.2891405 +0.5770 129.7501078 23.1472146 4.2718163 9.7501078 23.1472146 4.2718163 +0.5780 129.1235197 23.0438906 4.2545805 9.1235197 23.0438906 4.2545805 +0.5790 128.5006290 22.9411387 4.2374326 8.5006290 22.9411387 4.2374326 +0.5800 127.8814095 22.8389551 4.2203720 7.8814095 22.8389551 4.2203720 +0.5810 127.2658351 22.7373361 4.2033981 7.2658351 22.7373361 4.2033981 +0.5820 126.6538800 22.6362777 4.1865105 6.6538800 22.6362777 4.1865105 +0.5830 126.0455185 22.5357763 4.1697085 6.0455185 22.5357763 4.1697085 +0.5840 125.4407252 22.4358282 4.1529917 5.4407252 22.4358282 4.1529917 +0.5850 124.8394749 22.3364296 4.1363594 4.8394749 22.3364296 4.1363594 +0.5860 124.2417426 22.2375768 4.1198113 4.2417426 22.2375768 4.1198113 +0.5870 123.6475035 22.1392663 4.1033467 3.6475035 22.1392663 4.1033467 +0.5880 123.0567330 22.0414944 4.0869651 3.0567330 22.0414944 4.0869651 +0.5890 122.4694068 21.9442576 4.0706661 2.4694068 21.9442576 4.0706661 +0.5900 121.8855007 21.8475522 4.0544491 1.8855007 21.8475522 4.0544491 +0.5910 121.3049907 21.7513748 4.0383135 1.3049907 21.7513748 4.0383135 +0.5920 120.7278530 21.6557219 4.0222590 0.7278530 21.6557219 4.0222590 +0.5930 120.1540640 21.5605900 4.0062849 0.1540640 21.5605900 4.0062849 +0.5940 119.5836004 21.4659757 3.9903909 9.5836004 21.4659757 3.9903909 +0.5950 119.0164390 21.3718756 3.9745764 9.0164390 21.3718756 3.9745764 +0.5960 118.4525566 21.2782862 3.9588408 8.4525566 21.2782862 3.9588408 +0.5970 117.8919307 21.1852042 3.9431838 7.8919307 21.1852042 3.9431838 +0.5980 117.3345384 21.0926262 3.9276049 7.3345384 21.0926262 3.9276049 +0.5990 116.7803573 21.0005491 3.9121035 6.7803573 21.0005491 3.9121035 +0.6000 116.2293653 20.9089694 3.8966792 6.2293653 20.9089694 3.8966792 +0.6010 115.6815402 20.8178839 3.8813315 5.6815402 20.8178839 3.8813315 +0.6020 115.1368600 20.7272894 3.8660600 5.1368600 20.7272894 3.8660600 +0.6030 114.5953032 20.6371827 3.8508642 4.5953032 20.6371827 3.8508642 +0.6040 114.0568481 20.5475606 3.8357436 4.0568481 20.5475606 3.8357436 +0.6050 113.5214734 20.4584200 3.8206977 3.5214734 20.4584200 3.8206977 +0.6060 112.9891578 20.3697576 3.8057262 2.9891578 20.3697576 3.8057262 +0.6070 112.4598804 20.2815705 3.7908285 2.4598804 20.2815705 3.7908285 +0.6080 111.9336202 20.1938554 3.7760043 1.9336202 20.1938554 3.7760043 +0.6090 111.4103567 20.1066094 3.7612530 1.4103567 20.1066094 3.7612530 +0.6100 110.8900692 20.0198294 3.7465743 0.8900692 20.0198294 3.7465743 +0.6110 110.3727375 19.9335124 3.7319676 0.3727375 19.9335124 3.7319676 +0.6120 109.8583413 19.8476555 3.7174326 9.8583413 19.8476555 3.7174326 +0.6130 109.3468606 19.7622555 3.7029688 9.3468606 19.7622555 3.7029688 +0.6140 108.8382755 19.6773096 3.6885758 8.8382755 19.6773096 3.6885758 +0.6150 108.3325663 19.5928150 3.6742532 8.3325663 19.5928150 3.6742532 +0.6160 107.8297135 19.5087685 3.6600006 7.8297135 19.5087685 3.6600006 +0.6170 107.3296977 19.4251675 3.6458174 7.3296977 19.4251675 3.6458174 +0.6180 106.8324997 19.3420090 3.6317034 6.8324997 19.3420090 3.6317034 +0.6190 106.3381002 19.2592901 3.6176581 6.3381002 19.2592901 3.6176581 +0.6200 105.8464805 19.1770082 3.6036811 5.8464805 19.1770082 3.6036811 +0.6210 105.3576217 19.0951603 3.5897719 5.3576217 19.0951603 3.5897719 +0.6220 104.8715052 19.0137437 3.5759302 4.8715052 19.0137437 3.5759302 +0.6230 104.3881125 18.9327557 3.5621556 4.3881125 18.9327557 3.5621556 +0.6240 103.9074253 18.8521936 3.5484477 3.9074253 18.8521936 3.5484477 +0.6250 103.4294254 18.7720545 3.5348060 3.4294254 18.7720545 3.5348060 +0.6260 102.9540947 18.6923359 3.5212303 2.9540947 18.6923359 3.5212303 +0.6270 102.4814152 18.6130350 3.5077200 2.4814152 18.6130350 3.5077200 +0.6280 102.0113694 18.5341493 3.4942748 2.0113694 18.5341493 3.4942748 +0.6290 101.5439394 18.4556760 3.4808944 1.5439394 18.4556760 3.4808944 +0.6300 101.0791079 18.3776126 3.4675783 1.0791079 18.3776126 3.4675783 +0.6310 100.6168575 18.2999565 3.4543261 0.6168575 18.2999565 3.4543261 +0.6320 100.1571709 18.2227051 3.4411376 0.1571709 18.2227051 3.4411376 +0.6330 99.7000311 18.1458558 3.4280123 9.7000311 18.1458558 3.4280123 +0.6340 99.2454212 18.0694062 3.4149498 9.2454212 18.0694062 3.4149498 +0.6350 98.7933242 17.9933537 3.4019497 8.7933242 17.9933537 3.4019497 +0.6360 98.3437237 17.9176958 3.3890118 8.3437237 17.9176958 3.3890118 +0.6370 97.8966029 17.8424301 3.3761357 7.8966029 17.8424301 3.3761357 +0.6380 97.4519455 17.7675540 3.3633209 7.4519455 17.7675540 3.3633209 +0.6390 97.0097351 17.6930652 3.3505671 7.0097351 17.6930652 3.3505671 +0.6400 96.5699557 17.6189612 3.3378740 6.5699557 17.6189612 3.3378740 +0.6410 96.1325912 17.5452397 3.3252412 6.1325912 17.5452397 3.3252412 +0.6420 95.6976256 17.4718981 3.3126684 5.6976256 17.4718981 3.3126684 +0.6430 95.2650431 17.3989342 3.3001551 5.2650431 17.3989342 3.3001551 +0.6440 94.8348282 17.3263457 3.2877012 4.8348282 17.3263457 3.2877012 +0.6450 94.4069651 17.2541301 3.2753062 4.4069651 17.2541301 3.2753062 +0.6460 93.9814386 17.1822851 3.2629697 3.9814386 17.1822851 3.2629697 +0.6470 93.5582333 17.1108086 3.2506915 3.5582333 17.1108086 3.2506915 +0.6480 93.1373339 17.0396981 3.2384713 3.1373339 17.0396981 3.2384713 +0.6490 92.7187255 16.9689514 3.2263085 2.7187255 16.9689514 3.2263085 +0.6500 92.3023930 16.8985663 3.2142031 2.3023930 16.8985663 3.2142031 +0.6510 91.8883217 16.8285405 3.2021545 1.8883217 16.8285405 3.2021545 +0.6520 91.4764967 16.7588718 3.1901625 1.4764967 16.7588718 3.1901625 +0.6530 91.0669035 16.6895580 3.1782268 1.0669035 16.6895580 3.1782268 +0.6540 90.6595276 16.6205970 3.1663470 0.6595276 16.6205970 3.1663470 +0.6550 90.2543545 16.5519865 3.1545229 0.2543545 16.5519865 3.1545229 +0.6560 89.8513700 16.4837244 3.1427540 9.8513700 16.4837244 3.1427540 +0.6570 89.4505599 16.4158086 3.1310401 9.4505599 16.4158086 3.1310401 +0.6580 89.0519101 16.3482369 3.1193809 9.0519101 16.3482369 3.1193809 +0.6590 88.6554066 16.2810073 3.1077761 8.6554066 16.2810073 3.1077761 +0.6600 88.2610357 16.2141176 3.0962252 8.2610357 16.2141176 3.0962252 +0.6610 87.8687835 16.1475658 3.0847282 7.8687835 16.1475658 3.0847282 +0.6620 87.4786365 16.0813498 3.0732845 7.4786365 16.0813498 3.0732845 +0.6630 87.0905810 16.0154675 3.0618940 7.0905810 16.0154675 3.0618940 +0.6640 86.7046036 15.9499170 3.0505563 6.7046036 15.9499170 3.0505563 +0.6650 86.3206910 15.8846962 3.0392712 6.3206910 15.8846962 3.0392712 +0.6660 85.9388300 15.8198031 3.0280382 5.9388300 15.8198031 3.0280382 +0.6670 85.5590074 15.7552357 3.0168572 5.5590074 15.7552357 3.0168572 +0.6680 85.1812101 15.6909920 3.0057279 5.1812101 15.6909920 3.0057279 +0.6690 84.8054253 15.6270702 2.9946499 4.8054253 15.6270702 2.9946499 +0.6700 84.4316400 15.5634682 2.9836230 4.4316400 15.5634682 2.9836230 +0.6710 84.0598416 15.5001840 2.9726468 4.0598416 15.5001840 2.9726468 +0.6720 83.6900173 15.4372159 2.9617211 3.6900173 15.4372159 2.9617211 +0.6730 83.3221547 15.3745619 2.9508456 3.3221547 15.3745619 2.9508456 +0.6740 82.9562412 15.3122200 2.9400201 2.9562412 15.3122200 2.9400201 +0.6750 82.5922645 15.2501885 2.9292442 2.5922645 15.2501885 2.9292442 +0.6760 82.2302123 15.1884654 2.9185176 2.2302123 15.1884654 2.9185176 +0.6770 81.8700724 15.1270489 2.9078402 1.8700724 15.1270489 2.9078402 +0.6780 81.5118328 15.0659372 2.8972116 1.5118328 15.0659372 2.8972116 +0.6790 81.1554813 15.0051284 2.8866315 1.1554813 15.0051284 2.8866315 +0.6800 80.8010062 14.9446207 2.8760997 0.8010062 14.9446207 2.8760997 +0.6810 80.4483955 14.8844124 2.8656158 0.4483955 14.8844124 2.8656158 +0.6820 80.0976376 14.8245015 2.8551798 0.0976376 14.8245015 2.8551798 +0.6830 79.7487207 14.7648865 2.8447912 9.7487207 14.7648865 2.8447912 +0.6840 79.4016333 14.7055654 2.8344498 9.4016333 14.7055654 2.8344498 +0.6850 79.0563639 14.6465366 2.8241553 9.0563639 14.6465366 2.8241553 +0.6860 78.7129012 14.5877983 2.8139075 8.7129012 14.5877983 2.8139075 +0.6870 78.3712338 14.5293488 2.8037062 8.3712338 14.5293488 2.8037062 +0.6880 78.0313504 14.4711864 2.7935510 8.0313504 14.4711864 2.7935510 +0.6890 77.6932400 14.4133093 2.7834418 7.6932400 14.4133093 2.7834418 +0.6900 77.3568914 14.3557160 2.7733782 7.3568914 14.3557160 2.7733782 +0.6910 77.0222938 14.2984046 2.7633600 7.0222938 14.2984046 2.7633600 +0.6920 76.6894360 14.2413736 2.7533870 6.6894360 14.2413736 2.7533870 +0.6930 76.3583075 14.1846212 2.7434589 6.3583075 14.1846212 2.7434589 +0.6940 76.0288973 14.1281459 2.7335755 6.0288973 14.1281459 2.7335755 +0.6950 75.7011949 14.0719461 2.7237365 5.7011949 14.0719461 2.7237365 +0.6960 75.3751896 14.0160200 2.7139416 5.3751896 14.0160200 2.7139416 +0.6970 75.0508709 13.9603661 2.7041907 5.0508709 13.9603661 2.7041907 +0.6980 74.7282285 13.9049827 2.6944835 4.7282285 13.9049827 2.6944835 +0.6990 74.4072519 13.8498684 2.6848198 4.4072519 13.8498684 2.6848198 +0.7000 74.0879308 13.7950215 2.6751993 4.0879308 13.7950215 2.6751993 +0.7010 73.7702551 13.7404405 2.6656217 3.7702551 13.7404405 2.6656217 +0.7020 73.4542145 13.6861238 2.6560870 3.4542145 13.6861238 2.6560870 +0.7030 73.1397992 13.6320698 2.6465947 3.1397992 13.6320698 2.6465947 +0.7040 72.8269989 13.5782771 2.6371447 2.8269989 13.5782771 2.6371447 +0.7050 72.5158039 13.5247441 2.6277368 2.5158039 13.5247441 2.6277368 +0.7060 72.2062043 13.4714693 2.6183707 2.2062043 13.4714693 2.6183707 +0.7070 71.8981904 13.4184512 2.6090463 1.8981904 13.4184512 2.6090463 +0.7080 71.5917523 13.3656882 2.5997632 1.5917523 13.3656882 2.5997632 +0.7090 71.2868805 13.3131791 2.5905212 1.2868805 13.3131791 2.5905212 +0.7100 70.9835654 13.2609221 2.5813202 0.9835654 13.2609221 2.5813202 +0.7110 70.6817975 13.2089160 2.5721599 0.6817975 13.2089160 2.5721599 +0.7120 70.3815674 13.1571592 2.5630401 0.3815674 13.1571592 2.5630401 +0.7130 70.0828658 13.1056503 2.5539606 0.0828658 13.1056503 2.5539606 +0.7140 69.7856832 13.0543879 2.5449211 9.7856832 13.0543879 2.5449211 +0.7150 69.4900106 13.0033705 2.5359215 9.4900106 13.0033705 2.5359215 +0.7160 69.1958387 12.9525968 2.5269615 9.1958387 12.9525968 2.5269615 +0.7170 68.9031585 12.9020653 2.5180409 8.9031585 12.9020653 2.5180409 +0.7180 68.6119608 12.8517746 2.5091596 8.6119608 12.8517746 2.5091596 +0.7190 68.3222368 12.8017234 2.5003172 8.3222368 12.8017234 2.5003172 +0.7200 68.0339775 12.7519103 2.4915136 8.0339775 12.7519103 2.4915136 +0.7210 67.7471742 12.7023340 2.4827486 7.7471742 12.7023340 2.4827486 +0.7220 67.4618179 12.6529930 2.4740220 7.4618179 12.6529930 2.4740220 +0.7230 67.1779001 12.6038860 2.4653335 7.1779001 12.6038860 2.4653335 +0.7240 66.8954120 12.5550117 2.4566830 6.8954120 12.5550117 2.4566830 +0.7250 66.6143450 12.5063688 2.4480703 6.6143450 12.5063688 2.4480703 +0.7260 66.3346907 12.4579559 2.4394952 6.3346907 12.4579559 2.4394952 +0.7270 66.0564406 12.4097717 2.4309574 6.0564406 12.4097717 2.4309574 +0.7280 65.7795861 12.3618149 2.4224568 5.7795861 12.3618149 2.4224568 +0.7290 65.5041191 12.3140843 2.4139931 5.5041191 12.3140843 2.4139931 +0.7300 65.2300311 12.2665786 2.4055663 5.2300311 12.2665786 2.4055663 +0.7310 64.9573141 12.2192964 2.3971760 4.9573141 12.2192964 2.3971760 +0.7320 64.6859596 12.1722365 2.3888221 4.6859596 12.1722365 2.3888221 +0.7330 64.4159598 12.1253977 2.3805044 4.4159598 12.1253977 2.3805044 +0.7340 64.1473064 12.0787787 2.3722227 4.1473064 12.0787787 2.3722227 +0.7350 63.8799915 12.0323782 2.3639768 3.8799915 12.0323782 2.3639768 +0.7360 63.6140072 11.9861951 2.3557666 3.6140072 11.9861951 2.3557666 +0.7370 63.3493455 11.9402280 2.3475917 3.3493455 11.9402280 2.3475917 +0.7380 63.0859986 11.8944759 2.3394522 3.0859986 11.8944759 2.3394522 +0.7390 62.8239587 11.8489374 2.3313477 2.8239587 11.8489374 2.3313477 +0.7400 62.5632181 11.8036114 2.3232781 2.5632181 11.8036114 2.3232781 +0.7410 62.3037690 11.7584966 2.3152432 2.3037690 11.7584966 2.3152432 +0.7420 62.0456040 11.7135920 2.3072429 2.0456040 11.7135920 2.3072429 +0.7430 61.7887153 11.6688963 2.2992769 1.7887153 11.6688963 2.2992769 +0.7440 61.5330955 11.6244083 2.2913450 1.5330955 11.6244083 2.2913450 +0.7450 61.2787372 11.5801269 2.2834471 1.2787372 11.5801269 2.2834471 +0.7460 61.0256328 11.5360509 2.2755831 1.0256328 11.5360509 2.2755831 +0.7470 60.7737751 11.4921792 2.2677526 0.7737751 11.4921792 2.2677526 +0.7480 60.5231566 11.4485107 2.2599557 0.5231566 11.4485107 2.2599557 +0.7490 60.2737703 11.4050442 2.2521920 0.2737703 11.4050442 2.2521920 +0.7500 60.0256087 11.3617785 2.2444614 0.0256087 11.3617785 2.2444614 +0.7510 59.7786649 11.3187127 2.2367638 9.7786649 11.3187127 2.2367638 +0.7520 59.5329316 11.2758455 2.2290989 9.5329316 11.2758455 2.2290989 +0.7530 59.2884018 11.2331758 2.2214666 9.2884018 11.2331758 2.2214666 +0.7540 59.0450684 11.1907026 2.2138668 9.0450684 11.1907026 2.2138668 +0.7550 58.8029246 11.1484248 2.2062993 8.8029246 11.1484248 2.2062993 +0.7560 58.5619634 11.1063413 2.1987638 8.5619634 11.1063413 2.1987638 +0.7570 58.3221779 11.0644510 2.1912603 8.3221779 11.0644510 2.1912603 +0.7580 58.0835612 11.0227529 2.1837885 8.0835612 11.0227529 2.1837885 +0.7590 57.8461067 10.9812459 2.1763483 7.8461067 10.9812459 2.1763483 +0.7600 57.6098075 10.9399289 2.1689396 7.6098075 10.9399289 2.1689396 +0.7610 57.3746570 10.8988008 2.1615622 7.3746570 10.8988008 2.1615622 +0.7620 57.1406485 10.8578608 2.1542159 7.1406485 10.8578608 2.1542159 +0.7630 56.9077755 10.8171077 2.1469005 6.9077755 10.8171077 2.1469005 +0.7640 56.6760313 10.7765404 2.1396160 6.6760313 10.7765404 2.1396160 +0.7650 56.4454095 10.7361581 2.1323621 6.4454095 10.7361581 2.1323621 +0.7660 56.2159036 10.6959596 2.1251387 6.2159036 10.6959596 2.1251387 +0.7670 55.9875072 10.6559440 2.1179456 5.9875072 10.6559440 2.1179456 +0.7680 55.7602140 10.6161102 2.1107827 5.7602140 10.6161102 2.1107827 +0.7690 55.5340174 10.5764573 2.1036499 5.5340174 10.5764573 2.1036499 +0.7700 55.3089114 10.5369842 2.0965469 5.3089114 10.5369842 2.0965469 +0.7710 55.0848896 10.4976901 2.0894736 5.0848896 10.4976901 2.0894736 +0.7720 54.8619458 10.4585739 2.0824299 4.8619458 10.4585739 2.0824299 +0.7730 54.6400739 10.4196346 2.0754157 4.6400739 10.4196346 2.0754157 +0.7740 54.4192677 10.3808713 2.0684307 4.4192677 10.3808713 2.0684307 +0.7750 54.1995211 10.3422831 2.0614748 4.1995211 10.3422831 2.0614748 +0.7760 53.9808282 10.3038690 2.0545479 3.9808282 10.3038690 2.0545479 +0.7770 53.7631829 10.2656280 2.0476498 3.7631829 10.2656280 2.0476498 +0.7780 53.5465792 10.2275592 2.0407804 3.5465792 10.2275592 2.0407804 +0.7790 53.3310112 10.1896616 2.0339396 3.3310112 10.1896616 2.0339396 +0.7800 53.1164730 10.1519345 2.0271272 3.1164730 10.1519345 2.0271272 +0.7810 52.9029589 10.1143767 2.0203430 2.9029589 10.1143767 2.0203430 +0.7820 52.6904629 10.0769875 2.0135869 2.6904629 10.0769875 2.0135869 +0.7830 52.4789794 10.0397659 2.0068588 2.4789794 10.0397659 2.0068588 +0.7840 52.2685025 10.0027109 2.0001585 2.2685025 10.0027109 2.0001585 +0.7850 52.0590267 9.9658218 1.9934859 52.0590267 9.9658218 1.9934859 +0.7860 51.8505462 9.9290976 1.9868409 51.8505462 9.9290976 1.9868409 +0.7870 51.6430554 9.8925374 1.9802232 51.6430554 9.8925374 1.9802232 +0.7880 51.4365489 9.8561404 1.9736329 51.4365489 9.8561404 1.9736329 +0.7890 51.2310209 9.8199057 1.9670697 51.2310209 9.8199057 1.9670697 +0.7900 51.0264660 9.7838323 1.9605335 51.0264660 9.7838323 1.9605335 +0.7910 50.8228788 9.7479195 1.9540241 50.8228788 9.7479195 1.9540241 +0.7920 50.6202538 9.7121664 1.9475415 50.6202538 9.7121664 1.9475415 +0.7930 50.4185857 9.6765721 1.9410855 50.4185857 9.6765721 1.9410855 +0.7940 50.2178690 9.6411358 1.9346560 50.2178690 9.6411358 1.9346560 +0.7950 50.0180984 9.6058567 1.9282528 50.0180984 9.6058567 1.9282528 +0.7960 49.8192687 9.5707338 1.9218759 49.8192687 9.5707338 1.9218759 +0.7970 49.6213746 9.5357665 1.9155250 49.6213746 9.5357665 1.9155250 +0.7980 49.4244109 9.5009538 1.9092000 49.4244109 9.5009538 1.9092000 +0.7990 49.2283723 9.4662949 1.9029009 49.2283723 9.4662949 1.9029009 +0.8000 49.0332538 9.4317890 1.8966275 49.0332538 9.4317890 1.8966275 +0.8010 48.8390502 9.3974354 1.8903796 48.8390502 9.3974354 1.8903796 +0.8020 48.6457564 9.3632331 1.8841572 48.6457564 9.3632331 1.8841572 +0.8030 48.4533674 9.3291814 1.8779600 48.4533674 9.3291814 1.8779600 +0.8040 48.2618782 9.2952795 1.8717881 48.2618782 9.2952795 1.8717881 +0.8050 48.0712838 9.2615267 1.8656413 48.0712838 9.2615267 1.8656413 +0.8060 47.8815791 9.2279220 1.8595194 47.8815791 9.2279220 1.8595194 +0.8070 47.6927594 9.1944649 1.8534223 47.6927594 9.1944649 1.8534223 +0.8080 47.5048196 9.1611543 1.8473499 47.5048196 9.1611543 1.8473499 +0.8090 47.3177550 9.1279897 1.8413020 47.3177550 9.1279897 1.8413020 +0.8100 47.1315608 9.0949702 1.8352787 47.1315608 9.0949702 1.8352787 +0.8110 46.9462320 9.0620951 1.8292797 46.9462320 9.0620951 1.8292797 +0.8120 46.7617641 9.0293636 1.8233049 46.7617641 9.0293636 1.8233049 +0.8130 46.5781521 8.9967749 1.8173541 46.5781521 8.9967749 1.8173541 +0.8140 46.3953915 8.9643283 1.8114274 46.3953915 8.9643283 1.8114274 +0.8150 46.2134775 8.9320232 1.8055246 46.2134775 8.9320232 1.8055246 +0.8160 46.0324056 8.8998586 1.7996454 46.0324056 8.8998586 1.7996454 +0.8170 45.8521710 8.8678339 1.7937900 45.8521710 8.8678339 1.7937900 +0.8180 45.6727692 8.8359484 1.7879580 45.6727692 8.8359484 1.7879580 +0.8190 45.4941957 8.8042014 1.7821494 45.4941957 8.8042014 1.7821494 +0.8200 45.3164460 8.7725920 1.7763642 45.3164460 8.7725920 1.7763642 +0.8210 45.1395155 8.7411197 1.7706021 45.1395155 8.7411197 1.7706021 +0.8220 44.9633997 8.7097837 1.7648630 44.9633997 8.7097837 1.7648630 +0.8230 44.7880943 8.6785832 1.7591470 44.7880943 8.6785832 1.7591470 +0.8240 44.6135948 8.6475176 1.7534537 44.6135948 8.6475176 1.7534537 +0.8250 44.4398968 8.6165862 1.7477832 44.4398968 8.6165862 1.7477832 +0.8260 44.2669961 8.5857883 1.7421353 44.2669961 8.5857883 1.7421353 +0.8270 44.0948882 8.5551232 1.7365100 44.0948882 8.5551232 1.7365100 +0.8280 43.9235689 8.5245902 1.7309070 43.9235689 8.5245902 1.7309070 +0.8290 43.7530338 8.4941886 1.7253263 43.7530338 8.4941886 1.7253263 +0.8300 43.5832788 8.4639178 1.7197678 43.5832788 8.4639178 1.7197678 +0.8310 43.4142997 8.4337771 1.7142314 43.4142997 8.4337771 1.7142314 +0.8320 43.2460921 8.4037657 1.7087170 43.2460921 8.4037657 1.7087170 +0.8330 43.0786521 8.3738831 1.7032245 43.0786521 8.3738831 1.7032245 +0.8340 42.9119754 8.3441286 1.6977537 42.9119754 8.3441286 1.6977537 +0.8350 42.7460579 8.3145015 1.6923045 42.7460579 8.3145015 1.6923045 +0.8360 42.5808956 8.2850012 1.6868770 42.5808956 8.2850012 1.6868770 +0.8370 42.4164843 8.2556269 1.6814708 42.4164843 8.2556269 1.6814708 +0.8380 42.2528201 8.2263782 1.6760861 42.2528201 8.2263782 1.6760861 +0.8390 42.0898989 8.1972543 1.6707225 42.0898989 8.1972543 1.6707225 +0.8400 41.9277168 8.1682546 1.6653802 41.9277168 8.1682546 1.6653802 +0.8410 41.7662698 8.1393784 1.6600588 41.7662698 8.1393784 1.6600588 +0.8420 41.6055540 8.1106252 1.6547585 41.6055540 8.1106252 1.6547585 +0.8430 41.4455654 8.0819943 1.6494789 41.4455654 8.0819943 1.6494789 +0.8440 41.2863002 8.0534850 1.6442201 41.2863002 8.0534850 1.6442201 +0.8450 41.1277545 8.0250968 1.6389820 41.1277545 8.0250968 1.6389820 +0.8460 40.9699245 7.9968291 1.6337644 40.9699245 7.9968291 1.6337644 +0.8470 40.8128063 7.9686812 1.6285673 40.8128063 7.9686812 1.6285673 +0.8480 40.6563962 7.9406525 1.6233905 40.6563962 7.9406525 1.6233905 +0.8490 40.5006904 7.9127424 1.6182340 40.5006904 7.9127424 1.6182340 +0.8500 40.3456852 7.8849503 1.6130976 40.3456852 7.8849503 1.6130976 +0.8510 40.1913769 7.8572757 1.6079813 40.1913769 7.8572757 1.6079813 +0.8520 40.0377617 7.8297179 1.6028850 40.0377617 7.8297179 1.6028850 +0.8530 39.8848360 7.8022763 1.5978086 39.8848360 7.8022763 1.5978086 +0.8540 39.7325961 7.7749504 1.5927520 39.7325961 7.7749504 1.5927520 +0.8550 39.5810385 7.7477396 1.5877150 39.5810385 7.7477396 1.5877150 +0.8560 39.4301594 7.7206432 1.5826977 39.4301594 7.7206432 1.5826977 +0.8570 39.2799554 7.6936608 1.5776998 39.2799554 7.6936608 1.5776998 +0.8580 39.1304228 7.6667916 1.5727214 39.1304228 7.6667916 1.5727214 +0.8590 38.9815582 7.6400353 1.5677623 38.9815582 7.6400353 1.5677623 +0.8600 38.8333580 7.6133912 1.5628224 38.8333580 7.6133912 1.5628224 +0.8610 38.6858187 7.5868587 1.5579017 38.6858187 7.5868587 1.5579017 +0.8620 38.5389369 7.5604372 1.5530001 38.5389369 7.5604372 1.5530001 +0.8630 38.3927091 7.5341263 1.5481174 38.3927091 7.5341263 1.5481174 +0.8640 38.2471318 7.5079254 1.5432535 38.2471318 7.5079254 1.5432535 +0.8650 38.1022017 7.4818339 1.5384085 38.1022017 7.4818339 1.5384085 +0.8660 37.9579153 7.4558513 1.5335822 37.9579153 7.4558513 1.5335822 +0.8670 37.8142694 7.4299770 1.5287744 37.8142694 7.4299770 1.5287744 +0.8680 37.6712605 7.4042105 1.5239853 37.6712605 7.4042105 1.5239853 +0.8690 37.5288854 7.3785512 1.5192145 37.5288854 7.3785512 1.5192145 +0.8700 37.3871406 7.3529987 1.5144621 37.3871406 7.3529987 1.5144621 +0.8710 37.2460231 7.3275524 1.5097280 37.2460231 7.3275524 1.5097280 +0.8720 37.1055294 7.3022117 1.5050120 37.1055294 7.3022117 1.5050120 +0.8730 36.9656563 7.2769761 1.5003142 36.9656563 7.2769761 1.5003142 +0.8740 36.8264006 7.2518452 1.4956344 36.8264006 7.2518452 1.4956344 +0.8750 36.6877592 7.2268184 1.4909725 36.6877592 7.2268184 1.4909725 +0.8760 36.5497287 7.2018952 1.4863284 36.5497287 7.2018952 1.4863284 +0.8770 36.4123061 7.1770750 1.4817022 36.4123061 7.1770750 1.4817022 +0.8780 36.2754882 7.1523574 1.4770936 36.2754882 7.1523574 1.4770936 +0.8790 36.1392719 7.1277418 1.4725026 36.1392719 7.1277418 1.4725026 +0.8800 36.0036540 7.1032278 1.4679291 36.0036540 7.1032278 1.4679291 +0.8810 35.8686315 7.0788149 1.4633731 35.8686315 7.0788149 1.4633731 +0.8820 35.7342013 7.0545025 1.4588345 35.7342013 7.0545025 1.4588345 +0.8830 35.6003603 7.0302902 1.4543131 35.6003603 7.0302902 1.4543131 +0.8840 35.4671056 7.0061774 1.4498089 35.4671056 7.0061774 1.4498089 +0.8850 35.3344340 6.9821637 1.4453219 35.3344340 6.9821637 1.4453219 +0.8860 35.2023426 6.9582486 1.4408519 35.2023426 6.9582486 1.4408519 +0.8870 35.0708284 6.9344316 1.4363988 35.0708284 6.9344316 1.4363988 +0.8880 34.9398885 6.9107122 1.4319627 34.9398885 6.9107122 1.4319627 +0.8890 34.8095199 6.8870900 1.4275434 34.8095199 6.8870900 1.4275434 +0.8900 34.6797197 6.8635645 1.4231408 34.6797197 6.8635645 1.4231408 +0.8910 34.5504849 6.8401351 1.4187548 34.5504849 6.8401351 1.4187548 +0.8920 34.4218127 6.8168015 1.4143854 34.4218127 6.8168015 1.4143854 +0.8930 34.2937002 6.7935632 1.4100326 34.2937002 6.7935632 1.4100326 +0.8940 34.1661445 6.7704197 1.4056962 34.1661445 6.7704197 1.4056962 +0.8950 34.0391429 6.7473705 1.4013761 34.0391429 6.7473705 1.4013761 +0.8960 33.9126924 6.7244151 1.3970723 33.9126924 6.7244151 1.3970723 +0.8970 33.7867903 6.7015532 1.3927847 33.7867903 6.7015532 1.3927847 +0.8980 33.6614338 6.6787843 1.3885133 33.6614338 6.6787843 1.3885133 +0.8990 33.5366200 6.6561079 1.3842579 33.5366200 6.6561079 1.3842579 +0.9000 33.4123463 6.6335236 1.3800185 33.4123463 6.6335236 1.3800185 +0.9010 33.2886100 6.6110309 1.3757950 33.2886100 6.6110309 1.3757950 +0.9020 33.1654082 6.5886293 1.3715874 33.1654082 6.5886293 1.3715874 +0.9030 33.0427382 6.5663185 1.3673955 33.0427382 6.5663185 1.3673955 +0.9040 32.9205975 6.5440981 1.3632194 32.9205975 6.5440981 1.3632194 +0.9050 32.7989833 6.5219674 1.3590588 32.7989833 6.5219674 1.3590588 +0.9060 32.6778929 6.4999262 1.3549139 32.6778929 6.4999262 1.3549139 +0.9070 32.5573237 6.4779741 1.3507844 32.5573237 6.4779741 1.3507844 +0.9080 32.4372731 6.4561104 1.3466703 32.4372731 6.4561104 1.3466703 +0.9090 32.3177384 6.4343350 1.3425716 32.3177384 6.4343350 1.3425716 +0.9100 32.1987171 6.4126472 1.3384881 32.1987171 6.4126472 1.3384881 +0.9110 32.0802066 6.3910468 1.3344199 32.0802066 6.3910468 1.3344199 +0.9120 31.9622042 6.3695332 1.3303668 31.9622042 6.3695332 1.3303668 +0.9130 31.8447076 6.3481061 1.3263288 31.8447076 6.3481061 1.3263288 +0.9140 31.7277140 6.3267651 1.3223058 31.7277140 6.3267651 1.3223058 +0.9150 31.6112211 6.3055096 1.3182978 31.6112211 6.3055096 1.3182978 +0.9160 31.4952262 6.2843395 1.3143046 31.4952262 6.2843395 1.3143046 +0.9170 31.3797269 6.2632541 1.3103262 31.3797269 6.2632541 1.3103262 +0.9180 31.2647207 6.2422532 1.3063625 31.2647207 6.2422532 1.3063625 +0.9190 31.1502052 6.2213362 1.3024136 31.1502052 6.2213362 1.3024136 +0.9200 31.0361779 6.2005029 1.2984792 31.0361779 6.2005029 1.2984792 +0.9210 30.9226363 6.1797528 1.2945594 30.9226363 6.1797528 1.2945594 +0.9220 30.8095781 6.1590855 1.2906541 30.8095781 6.1590855 1.2906541 +0.9230 30.6970008 6.1385007 1.2867632 30.6970008 6.1385007 1.2867632 +0.9240 30.5849021 6.1179979 1.2828866 30.5849021 6.1179979 1.2828866 +0.9250 30.4732795 6.0975767 1.2790243 30.4732795 6.0975767 1.2790243 +0.9260 30.3621308 6.0772368 1.2751763 30.3621308 6.0772368 1.2751763 +0.9270 30.2514534 6.0569777 1.2713424 30.2514534 6.0569777 1.2713424 +0.9280 30.1412452 6.0367991 1.2675226 30.1412452 6.0367991 1.2675226 +0.9290 30.0315037 6.0167007 1.2637168 30.0315037 6.0167007 1.2637168 +0.9300 29.9222268 5.9966820 1.2599250 29.9222268 5.9966820 1.2599250 +0.9310 29.8134120 5.9767426 1.2561471 29.8134120 5.9767426 1.2561471 +0.9320 29.7050570 5.9568822 1.2523831 29.7050570 5.9568822 1.2523831 +0.9330 29.5971597 5.9371004 1.2486328 29.5971597 5.9371004 1.2486328 +0.9340 29.4897178 5.9173969 1.2448963 29.4897178 5.9173969 1.2448963 +0.9350 29.3827290 5.8977712 1.2411735 29.3827290 5.8977712 1.2411735 +0.9360 29.2761910 5.8782230 1.2374642 29.2761910 5.8782230 1.2374642 +0.9370 29.1701017 5.8587520 1.2337685 29.1701017 5.8587520 1.2337685 +0.9380 29.0644589 5.8393577 1.2300863 29.0644589 5.8393577 1.2300863 +0.9390 28.9592603 5.8200398 1.2264175 28.9592603 5.8200398 1.2264175 +0.9400 28.8545038 5.8007980 1.2227621 28.8545038 5.8007980 1.2227621 +0.9410 28.7501872 5.7816319 1.2191200 28.7501872 5.7816319 1.2191200 +0.9420 28.6463084 5.7625412 1.2154912 28.6463084 5.7625412 1.2154912 +0.9430 28.5428651 5.7435254 1.2118755 28.5428651 5.7435254 1.2118755 +0.9440 28.4398554 5.7245843 1.2082730 28.4398554 5.7245843 1.2082730 +0.9450 28.3372770 5.7057175 1.2046836 28.3372770 5.7057175 1.2046836 +0.9460 28.2351278 5.6869246 1.2011071 28.2351278 5.6869246 1.2011071 +0.9470 28.1334058 5.6682053 1.1975437 28.1334058 5.6682053 1.1975437 +0.9480 28.0321089 5.6495593 1.1939931 28.0321089 5.6495593 1.1939931 +0.9490 27.9312350 5.6309862 1.1904554 27.9312350 5.6309862 1.1904554 +0.9500 27.8307820 5.6124856 1.1869305 27.8307820 5.6124856 1.1869305 +0.9510 27.7307479 5.5940574 1.1834183 27.7307479 5.5940574 1.1834183 +0.9520 27.6311306 5.5757010 1.1799189 27.6311306 5.5757010 1.1799189 +0.9530 27.5319282 5.5574162 1.1764320 27.5319282 5.5574162 1.1764320 +0.9540 27.4331386 5.5392026 1.1729577 27.4331386 5.5392026 1.1729577 +0.9550 27.3347598 5.5210600 1.1694960 27.3347598 5.5210600 1.1694960 +0.9560 27.2367898 5.5029879 1.1660467 27.2367898 5.5029879 1.1660467 +0.9570 27.1392266 5.4849861 1.1626098 27.1392266 5.4849861 1.1626098 +0.9580 27.0420683 5.4670542 1.1591853 27.0420683 5.4670542 1.1591853 +0.9590 26.9453130 5.4491919 1.1557730 26.9453130 5.4491919 1.1557730 +0.9600 26.8489586 5.4313990 1.1523731 26.8489586 5.4313990 1.1523731 +0.9610 26.7530032 5.4136749 1.1489853 26.7530032 5.4136749 1.1489853 +0.9620 26.6574449 5.3960196 1.1456096 26.6574449 5.3960196 1.1456096 +0.9630 26.5622819 5.3784326 1.1422461 26.5622819 5.3784326 1.1422461 +0.9640 26.4675121 5.3609136 1.1388946 26.4675121 5.3609136 1.1388946 +0.9650 26.3731337 5.3434623 1.1355551 26.3731337 5.3434623 1.1355551 +0.9660 26.2791448 5.3260784 1.1322275 26.2791448 5.3260784 1.1322275 +0.9670 26.1855436 5.3087616 1.1289118 26.1855436 5.3087616 1.1289118 +0.9680 26.0923281 5.2915115 1.1256080 26.0923281 5.2915115 1.1256080 +0.9690 25.9994966 5.2743280 1.1223159 25.9994966 5.2743280 1.1223159 +0.9700 25.9070472 5.2572106 1.1190356 25.9070472 5.2572106 1.1190356 +0.9710 25.8149780 5.2401591 1.1157669 25.8149780 5.2401591 1.1157669 +0.9720 25.7232873 5.2231731 1.1125099 25.7232873 5.2231731 1.1125099 +0.9730 25.6319732 5.2062525 1.1092644 25.6319732 5.2062525 1.1092644 +0.9740 25.5410340 5.1893967 1.1060305 25.5410340 5.1893967 1.1060305 +0.9750 25.4504678 5.1726057 1.1028081 25.4504678 5.1726057 1.1028081 +0.9760 25.3602728 5.1558791 1.0995971 25.3602728 5.1558791 1.0995971 +0.9770 25.2704474 5.1392165 1.0963975 25.2704474 5.1392165 1.0963975 +0.9780 25.1809897 5.1226177 1.0932092 25.1809897 5.1226177 1.0932092 +0.9790 25.0918980 5.1060825 1.0900323 25.0918980 5.1060825 1.0900323 +0.9800 25.0031705 5.0896104 1.0868665 25.0031705 5.0896104 1.0868665 +0.9810 24.9148055 5.0732013 1.0837120 24.9148055 5.0732013 1.0837120 +0.9820 24.8268013 5.0568548 1.0805686 24.8268013 5.0568548 1.0805686 +0.9830 24.7391563 5.0405707 1.0774363 24.7391563 5.0405707 1.0774363 +0.9840 24.6518686 5.0243487 1.0743150 24.6518686 5.0243487 1.0743150 +0.9850 24.5649365 5.0081885 1.0712048 24.5649365 5.0081885 1.0712048 +0.9860 24.4783585 4.9920898 1.0681055 24.4783585 4.9920898 1.0681055 +0.9870 24.3921328 4.9760523 1.0650171 24.3921328 4.9760523 1.0650171 +0.9880 24.3062578 4.9600758 1.0619396 24.3062578 4.9600758 1.0619396 +0.9890 24.2207318 4.9441600 1.0588729 24.2207318 4.9441600 1.0588729 +0.9900 24.1355532 4.9283046 1.0558169 24.1355532 4.9283046 1.0558169 +0.9910 24.0507203 4.9125093 1.0527717 24.0507203 4.9125093 1.0527717 +0.9920 23.9662314 4.8967740 1.0497372 23.9662314 4.8967740 1.0497372 +0.9930 23.8820851 4.8810982 1.0467134 23.8820851 4.8810982 1.0467134 +0.9940 23.7982796 4.8654818 1.0437001 23.7982796 4.8654818 1.0437001 +0.9950 23.7148134 4.8499244 1.0406973 23.7148134 4.8499244 1.0406973 +0.9960 23.6316848 4.8344259 1.0377051 23.6316848 4.8344259 1.0377051 +0.9970 23.5488924 4.8189859 1.0347233 23.5488924 4.8189859 1.0347233 +0.9980 23.4664344 4.8036042 1.0317519 23.4664344 4.8036042 1.0317519 +0.9990 23.3843094 4.7882805 1.0287909 23.3843094 4.7882805 1.0287909 +1.0000 23.3025158 4.7730145 1.0258403 23.3025158 4.7730145 1.0258403