diff --git a/README.md b/README.md new file mode 100644 index 0000000..0e2a587 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# Deep Symbolic Regression \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..eee0691 --- /dev/null +++ b/setup.py @@ -0,0 +1,22 @@ +import os +from setuptools import setup, find_packages + +# Utility function to read the README file. +# Used for the long_description. It's nice, because now 1) we have a top level +# README file and 2) it's easier to type in the README file than to put a raw +# string in below ... +def read(fname): + return open(os.path.join(os.path.dirname(__file__), fname)).read() + +setup( + name = "symbolicregression", + version = "0.0.1", + author = "Pierre-Alexandre Kamienny", + author_email = "pakamienny@fb.com", + description = ("Performing Symbolic Regression with Transformers"), + license = "BSD", + keywords = "symbolic regression, transformers", + url = "", + packages=find_packages(), + long_description=read('README.md'), +) diff --git a/symbolicregression/__init__.py b/symbolicregression/__init__.py new file mode 100644 index 0000000..6bc3959 --- /dev/null +++ b/symbolicregression/__init__.py @@ -0,0 +1 @@ +from . import model diff --git a/symbolicregression/envs/__init__.py b/symbolicregression/envs/__init__.py new file mode 100644 index 0000000..a093d87 --- /dev/null +++ b/symbolicregression/envs/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) 2020-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from logging import getLogger + +# from .generators import operators_conv, Node +from .environment import FunctionEnvironment + +logger = getLogger() + + +ENVS = { + "functions": FunctionEnvironment, +} + + +def build_env(params): + """ + Build environment. 
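+
+    Usage sketch (assuming `params.env_name == "functions"` and
+    `params.tasks == "functions"`):
+
+        env = build_env(params)  # -> FunctionEnvironment, params.tasks == ["functions"]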
+ """ + env = ENVS[params.env_name](params) + + # tasks + tasks = [x for x in params.tasks.split(",") if len(x) > 0] + assert len(tasks) == len(set(tasks)) > 0 + assert all(task in env.TRAINING_TASKS for task in tasks) + params.tasks = tasks + logger.info(f'Training tasks: {", ".join(tasks)}') + + return env diff --git a/symbolicregression/envs/encoders.py b/symbolicregression/envs/encoders.py new file mode 100644 index 0000000..03c6336 --- /dev/null +++ b/symbolicregression/envs/encoders.py @@ -0,0 +1,230 @@ +from abc import ABC, abstractmethod +import numpy as np +import math +from .generators import Node, NodeList +from .utils import * + + +class Encoder(ABC): + """ + Base class for encoders, encodes and decodes matrices + abstract methods for encoding/decoding numbers + """ + + def __init__(self, params): + pass + + @abstractmethod + def encode(self, val): + pass + + @abstractmethod + def decode(self, lst): + pass + +class GeneralEncoder: + def __init__(self, params, symbols, all_operators): + self.float_encoder = FloatSequences(params) + self.equation_encoder = Equation(params, symbols, self.float_encoder, all_operators) + +class FloatSequences(Encoder): + def __init__(self, params): + super().__init__(params) + self.float_precision = params.float_precision + self.mantissa_len = params.mantissa_len + self.max_exponent = params.max_exponent + self.base = (self.float_precision + 1) // self.mantissa_len + self.max_token = 10 ** self.base + self.symbols = ["+", "-"] + self.symbols.extend( + ["N" + f"%0{self.base}d" % i for i in range(self.max_token)] + ) + self.symbols.extend( + ["E" + str(i) for i in range(-self.max_exponent, self.max_exponent + 1)] + ) + + def encode(self, values): + """ + Write a float number + """ + precision = self.float_precision + + if len(values.shape) == 1: + seq = [] + value = values + for val in value: + assert val not in [-np.inf, np.inf] + sign = "+" if val >= 0 else "-" + m, e = (f"%.{precision}e" % val).split("e") + i, f = m.lstrip("-").split(".") + i = i + f + tokens = chunks(i, self.base) + expon = int(e) - precision + if expon < -self.max_exponent: + tokens = ["0" * self.base] * self.mantissa_len + expon = int(0) + seq.extend([sign, *["N" + token for token in tokens], "E" + str(expon)]) + return seq + else: + seqs = [self.encode(values[0])] + N = values.shape[0] + for n in range(1, N): + seqs += [self.encode(values[n])] + return seqs + + def decode(self, lst): + """ + Parse a list that starts with a float. + Return the float value, and the position it ends in the list. 
+ """ + if len(lst) == 0: + return None + seq = [] + for val in chunks(lst, 2 + self.mantissa_len): + for x in val: + if x[0] not in ["-", "+", "E", "N"]: + return np.nan + try: + sign = 1 if val[0] == "+" else -1 + mant = "" + for x in val[1:-1]: + mant += x[1:] + mant = int(mant) + exp = int(val[-1][1:]) + value = sign * mant * (10 ** exp) + value = float(value) + except Exception: + value = np.nan + seq.append(value) + return seq + + +class Equation(Encoder): + def __init__(self, params, symbols, float_encoder, all_operators): + super().__init__(params) + self.params = params + self.max_int = self.params.max_int + self.symbols = symbols + if params.extra_unary_operators != "": + self.extra_unary_operators = self.params.extra_unary_operators.split(",") + else: + self.extra_unary_operators = [] + if params.extra_binary_operators != "": + self.extra_binary_operators = self.params.extra_binary_operators.split(",") + else: + self.extra_binary_operators = [] + self.float_encoder = float_encoder + self.all_operators=all_operators + + def encode(self, tree): + res = [] + for elem in tree.prefix().split(","): + try: + val = float(elem) + if elem.lstrip('-').isdigit(): + res.extend(self.write_int(int(elem))) + else: + res.extend(self.float_encoder.encode(np.array([val]))) + except ValueError: + res.append(elem) + return res + + def _decode(self, lst): + if len(lst) == 0: + return None, 0 + # elif (lst[0] not in self.symbols) and (not lst[0].lstrip("-").replace(".","").replace("e+", "").replace("e-","").isdigit()): + # return None, 0 + elif "OOD" in lst[0]: + return None, 0 + elif lst[0] in self.all_operators.keys(): + res = Node(lst[0], self.params) + arity = self.all_operators[lst[0]] + pos = 1 + for i in range(arity): + child, length = self._decode(lst[pos:]) + if child is None: + return None, pos + res.push_child(child) + pos += length + return res, pos + elif lst[0].startswith("INT"): + val, length = self.parse_int(lst) + return Node(str(val), self.params), length + elif lst[0]=="+" or lst[0]=="-": + try: + val = self.float_encoder.decode(lst[:3])[0] + except Exception as e: + #print(e, "error in encoding, lst: {}".format(lst)) + return None, 0 + return Node(str(val), self.params), 3 + elif lst[0].startswith("CONSTANT") or lst[0]=="y": ##added this manually CAREFUL!! + return Node(lst[0], self.params), 1 + elif lst[0] in self.symbols: + return Node(lst[0], self.params), 1 + else: + try: + float(lst[0]) #if number, return leaf + return Node(lst[0], self.params), 1 + except: + return None, 0 + + def split_at_value(self, lst, value): + indices = [i for i, x in enumerate(lst) if x == value] + res = [] + for start, end in zip( + [0, *[i + 1 for i in indices]], [*[i - 1 for i in indices], len(lst)] + ): + res.append(lst[start : end + 1]) + return res + + def decode(self, lst): + trees = [] + lists = self.split_at_value(lst, "|") + for lst in lists: + tree = self._decode(lst)[0] + if tree is None: + return None + trees.append(tree) + tree = NodeList(trees) + return tree + + def parse_int(self, lst): + """ + Parse a list that starts with an integer. + Return the integer value, and the position it ends in the list. + """ + base = self.max_int + val = 0 + i = 0 + for x in lst[1:]: + if not (x.rstrip("-").isdigit()): + break + val = val * base + int(x) + i += 1 + if base > 0 and lst[0] == "INT-": + val = -val + return val, i + 1 + + def write_int(self, val): + """ + Convert a decimal integer to a representation in the given base. 
+ """ + if not self.params.use_sympy: + return [str(val)] + + base = self.max_int + res = [] + max_digit = abs(base) + neg = val < 0 + val = -val if neg else val + while True: + rem = val % base + val = val // base + if rem < 0 or rem > max_digit: + rem -= base + val += 1 + res.append(str(rem)) + if val == 0: + break + res.append("INT-" if neg else "INT+") + return res[::-1] diff --git a/symbolicregression/envs/environment.py b/symbolicregression/envs/environment.py new file mode 100644 index 0000000..eeeff23 --- /dev/null +++ b/symbolicregression/envs/environment.py @@ -0,0 +1,1046 @@ +# Copyright (c) 2020-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from distutils.log import INFO +from logging import getLogger +import os +import io +import sys +import copy +import json +import operator +from typing import Optional, List, Dict +from collections import deque, defaultdict +import time +import traceback + +# import math +import numpy as np +import symbolicregression.envs.encoders as encoders +import symbolicregression.envs.generators as generators +from symbolicregression.envs.generators import all_operators +import symbolicregression.envs.simplifiers as simplifiers +from typing import Optional, Dict +import torch +import torch.nn.functional as F +from torch.utils.data.dataset import Dataset +from torch.utils.data import DataLoader +import collections +from .utils import * +from ..utils import bool_flag, timeout, MyTimeoutError +import math +import scipy + +SPECIAL_WORDS = [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "(", + ")", + "SPECIAL", + "OOD_unary_op", + "OOD_binary_op", + "OOD_constant", +] +logger = getLogger() + +SKIP_ITEM = "SKIP_ITEM" + +class FunctionEnvironment(object): + + TRAINING_TASKS = {"functions"} + + def __init__(self, params): + self.params = params + self.rng=None + self.float_precision = params.float_precision + self.mantissa_len = params.mantissa_len + self.max_size = None + self.float_tolerance = 10 ** (-params.float_precision) + self.additional_tolerance = [ + 10 ** (-i) for i in range(params.float_precision + 1) + ] + assert ( + params.float_precision + 1 + ) % params.mantissa_len == 0, "Bad precision/mantissa len ratio" + + self.generator = generators.RandomFunctions(params, SPECIAL_WORDS) + self.float_encoder = self.generator.float_encoder + self.float_words = self.generator.float_words + self.equation_encoder = self.generator.equation_encoder + self.equation_words = self.generator.equation_words + self.equation_words += self.float_words + + self.simplifier = simplifiers.Simplifier(self.generator) + + # number of words / indices + self.float_id2word = {i: s for i, s in enumerate(self.float_words)} + self.equation_id2word = {i: s for i, s in enumerate(self.equation_words)} + self.float_word2id = {s: i for i, s in self.float_id2word.items()} + self.equation_word2id = {s: i for i, s in self.equation_id2word.items()} + + for ood_unary_op in self.generator.extra_unary_operators: + self.equation_word2id[ood_unary_op] = self.equation_word2id["OOD_unary_op"] + for ood_binary_op in self.generator.extra_binary_operators: + self.equation_word2id[ood_binary_op] = self.equation_word2id[ + "OOD_binary_op" + ] + if self.generator.extra_constants is not None: + for c in self.generator.extra_constants: + self.equation_word2id[c] = self.equation_word2id["OOD_constant"] + + assert len(self.float_words) == len(set(self.float_words)) + 
+        assert len(self.equation_word2id) == len(set(self.equation_word2id))
+        self.n_words = params.n_words = len(self.equation_words)
+        logger.info(
+            f"vocabulary: {len(self.float_word2id)} float words, {len(self.equation_word2id)} equation words"
+        )
+
+    def mask_from_seperator(self, x, sep):
+        sep_id = self.float_word2id[sep]
+        alen = (
+            torch.arange(x.shape[0], dtype=torch.long, device=x.device)
+            .unsqueeze(-1)
+            .repeat(1, x.shape[1])
+        )
+        sep_id_occurence = torch.tensor(
+            [
+                torch.where(x[:, i] == sep_id)[0][0].item()
+                if len(torch.where(x[:, i] == sep_id)[0]) > 0
+                else -1
+                for i in range(x.shape[1])
+            ]
+        )
+        mask = alen > sep_id_occurence
+        return mask
+
+    def batch_equations(self, equations):
+        """
+        Take as input a list of n sequences (torch.LongTensor vectors) and return
+        a tensor of size (slen, n) where slen is the length of the longest
+        sentence, and a vector lengths containing the length of each sentence.
+        """
+        lengths = torch.LongTensor([2 + len(eq) for eq in equations])
+        sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(
+            self.float_word2id["<PAD>"]
+        )
+        sent[0] = self.equation_word2id["<EOS>"]
+        for i, eq in enumerate(equations):
+            sent[1 : lengths[i] - 1, i].copy_(eq)
+            sent[lengths[i] - 1, i] = self.equation_word2id["<EOS>"]
+        return sent, lengths
+
+    def word_to_idx(self, words, float_input=True):
+        if float_input:
+            return [
+                [
+                    torch.LongTensor([self.float_word2id[dim] for dim in point])
+                    for point in seq
+                ]
+                for seq in words
+            ]
+        else:
+            return [
+                torch.LongTensor([self.equation_word2id[w] for w in eq]) for eq in words
+            ]
+
+    def word_to_infix(self, words, is_float=True, str_array=True):
+        if is_float:
+            m = self.float_encoder.decode(words)
+            if m is None:
+                return None
+            if str_array:
+                return np.array2string(np.array(m))
+            else:
+                return np.array(m)
+        else:
+            m = self.equation_encoder.decode(words)
+            if m is None:
+                return None
+            if str_array:
+                return m.infix()
+            else:
+                return m
+
+    def wrap_equation_floats(self, tree, constants):
+        prefix = tree.prefix().split(",")
+        j = 0
+        for i, elem in enumerate(prefix):
+            if elem.startswith("CONSTANT"):
+                prefix[i] = str(constants[j])
+                j += 1
+        assert j == len(constants), "all constants were not fitted"
+        assert "CONSTANT" not in prefix, "tree {} got constant after wrapper {}".format(tree, constants)
+        tree_with_constants = self.word_to_infix(prefix, is_float=False, str_array=False)
+        return tree_with_constants
+
+    def idx_to_infix(self, lst, is_float=True, str_array=True):
+        if is_float:
+            idx_to_words = [self.float_id2word[int(i)] for i in lst]
+        else:
+            idx_to_words = [self.equation_id2word[int(term)] for term in lst]
+        return self.word_to_infix(idx_to_words, is_float, str_array)
+
+    def gen_expr(self, train, input_length_modulo=-1, nb_binary_ops=None, nb_unary_ops=None, input_dimension=None, output_dimension=None, n_input_points=None, input_distribution_type=None):
+        errors = defaultdict(int)
+        if not train or self.params.use_controller:
+            if nb_unary_ops is None:
+                nb_unary_ops = self.rng.randint(self.params.min_unary_ops, self.params.max_unary_ops + 1)
+            if input_dimension is None:
+                input_dimension = self.rng.randint(self.params.min_input_dimension, self.params.max_input_dimension + 1)
+        while True:
+            try:
+                expr, error = self._gen_expr(train, input_length_modulo=input_length_modulo, nb_binary_ops=nb_binary_ops, nb_unary_ops=nb_unary_ops, input_dimension=input_dimension, output_dimension=output_dimension, n_input_points=n_input_points, input_distribution_type=input_distribution_type)
+                if error:
+                    errors[error[0]] += 1
+                    assert False
+                return expr, errors
+            except Exception:
+                if self.params.debug:
+                    # print(expr['tree'])
+                    # print(traceback.format_exc())
+                    pass
+                continue
+
+    @timeout(1)
+    def _gen_expr(self, train, input_length_modulo=-1, nb_binary_ops=None, nb_unary_ops=None, input_dimension=None, output_dimension=None, n_input_points=None, input_distribution_type=None):
+
+        tree, original_input_dimension, output_dimension, nb_unary_ops, nb_binary_ops = self.generator.generate_multi_dimensional_tree(rng=self.rng, nb_unary_ops=nb_unary_ops, nb_binary_ops=nb_binary_ops, input_dimension=input_dimension, output_dimension=output_dimension)
+        if tree is None:
+            return {"tree": tree}, ["bad tree"]
+        sum_binary_ops = max(nb_binary_ops)
+        sum_unary_ops = max(nb_unary_ops)
+        sum_ops = sum_binary_ops + sum_unary_ops
+        input_dimension = self.generator.relabel_variables(tree)
+        if input_dimension == 0 or (self.params.enforce_dim and original_input_dimension > input_dimension):
+            return {"tree": tree}, ["bad input dimension"]
+
+        for op in self.params.operators_to_not_repeat.split(','):
+            if op and tree.prefix().count(op) > 1:
+                return {"tree": tree}, ["ops repeated"]
+
+        if self.params.use_sympy:
+            len_before = len(tree.prefix().split(','))
+            tree = self.simplifier.simplify_tree(tree)
+            len_after = len(tree.prefix().split(','))
+            if tree is None or len_after > 2 * len_before:
+                return {"tree": tree}, ["simplification error"]
+
+        dimensions = {
+            "input_dimension": input_dimension,
+            "output_dimension": output_dimension,
+        }
+        if n_input_points is None:
+            n_input_points = (
+                self.params.max_len
+                if not train
+                else self.rng.randint(min(self.params.min_len_per_dim * input_dimension, self.params.max_len), self.params.max_len + 1)
+            )
+
+        if train:
+            n_prediction_points = 0
+        else:
+            n_prediction_points = self.params.n_prediction_points
+
+        input_distribution_type_to_int = {'gaussian': 0, 'uniform': 1}
+        if input_distribution_type is None:
+            input_distribution_type = 'gaussian' if self.rng.random() < .5 else 'uniform'
+        n_centroids = self.rng.randint(1, self.params.max_centroids)
+
+        if self.params.prediction_sigmas is None:
+            prediction_sigmas = []
+        else:
+            prediction_sigmas = [float(sigma) for sigma in self.params.prediction_sigmas.split(",")]
+
+        tree, datapoints = self.generator.generate_datapoints(
+            tree=tree,
+            rng=self.rng,
+            input_dimension=dimensions["input_dimension"],
+            n_input_points=n_input_points,
+            n_prediction_points=n_prediction_points,
+            prediction_sigmas=prediction_sigmas,
+            input_distribution_type=input_distribution_type,
+            n_centroids=n_centroids,
+            max_trials=self.params.max_trials,
+        )
+
+        if datapoints is None:
+            return {"tree": tree}, ["generation error"]
+
+        x_to_fit, y_to_fit = datapoints["fit"]
+        predict_datapoints = copy.deepcopy(datapoints)
+        del predict_datapoints["fit"]
+
+        all_outputs = np.concatenate([y for k, (x, y) in datapoints.items()])
+
+        ## output noise added to y_to_fit
+        try:
+            gamma = self.rng.uniform(0, self.params.train_noise_gamma) if train else self.params.eval_noise_gamma
+            norm = scipy.linalg.norm((np.abs(all_outputs) + 1e-100) / np.sqrt(all_outputs.shape[0]))
+            noise = gamma * norm * self.rng.randn(*y_to_fit.shape)
+            y_to_fit += noise
+        except Exception as e:
+            logger.warning(f"norm computation error: {e}")
+            return {"tree": tree}, ["norm computation error"]
+
+        tree_encoded = self.equation_encoder.encode(tree)
+        skeleton_tree, _ = self.generator.function_to_skeleton(tree)
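+        # the skeleton is the same expression with every numeric constant
+        # replaced by a CONSTANT placeholder token (see function_to_skeleton)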
skeleton_tree_encoded = self.equation_encoder.encode(skeleton_tree) + + assert all([x in self.equation_word2id for x in tree_encoded]), "tree: {}\n encoded: {}".format(tree, tree_encoded) + + if input_length_modulo != -1 and not train: + indexes_to_keep = np.arange(min(input_length_modulo, self.params.max_len), self.params.max_len+1, step=input_length_modulo) + else: + indexes_to_keep = [n_input_points] + + X_to_fit, Y_to_fit = [], [] + info = {"n_input_points": [], "n_unary_ops": [], "n_binary_ops": [], "d_in": [], "d_out": [], "input_distribution_type": [], "n_centroids": []} + n_input_points = x_to_fit.shape[0] + + for idx in indexes_to_keep: + _x_to_fit = x_to_fit[:idx] if idx > 0 else x_to_fit + _y_to_fit = y_to_fit[:idx] if idx > 0 else y_to_fit + X_to_fit.append(_x_to_fit) + Y_to_fit.append(_y_to_fit) + info["n_input_points"].append(idx) + info["n_unary_ops"].append(sum(nb_unary_ops)) + info["n_binary_ops"].append(sum(nb_binary_ops)) + info["d_in"].append(dimensions["input_dimension"]) + info["d_out"].append(dimensions["output_dimension"]) + info["input_distribution_type"].append(input_distribution_type_to_int[input_distribution_type]) + info["n_centroids"].append(n_centroids) + + expr = { + "X_to_fit": X_to_fit, + "Y_to_fit": Y_to_fit, + "tree_encoded": tree_encoded, + "skeleton_tree_encoded": skeleton_tree_encoded, + "tree": tree, + "skeleton_tree": skeleton_tree, + "infos": info, + } + for k, (x,y) in predict_datapoints.items(): + expr["x_to_" + k]= x + expr["y_to_" + k]= y + return expr, [] + + + def create_train_iterator(self, task, data_path, params, **args): + """ + Create a dataset for this environment. + """ + logger.info(f"Creating train iterator for {task} ...") + dataset = EnvDataset( + self, + task, + train=True, + skip=self.params.queue_strategy is not None, + params=params, + path=(None if data_path is None else data_path[task][0]), + **args, + + ) + + if self.params.queue_strategy is None: collate_fn=dataset.collate_fn + else: + collate_fn = dataset.collate_reduce_padding( + dataset.collate_fn, + key_fn=lambda x: x["infos"]["input_sequence_length"]+ len(x["tree_encoded"]),# (x["infos"]["input_sequence_length"], len(x["tree_encoded"])), + max_size=self.max_size, + ) + return DataLoader( + dataset, + timeout=(0 if params.num_workers == 0 else 3600), + batch_size=params.batch_size, + num_workers=( + params.num_workers + if data_path is None or params.num_workers == 0 + else 1 + ), + shuffle=False, + collate_fn=collate_fn, + ) + + def create_test_iterator( + self, + data_type, + task, + data_path, + batch_size, + params, + size, + input_length_modulo, + **args, + ): + """ + Create a dataset for this environment. + """ + logger.info(f"Creating {data_type} iterator for {task} ...") + + dataset = EnvDataset( + self, + task, + train=False, + skip=False, + params=params, + path=(None if data_path is None else data_path[task][int(data_type[5:])]), + size=size, + type=data_type, + input_length_modulo=input_length_modulo, + **args, + ) + + return DataLoader( + dataset, + timeout=0, + batch_size=batch_size, + num_workers=1, + shuffle=False, + collate_fn=dataset.collate_fn, + ) + + @staticmethod + def register_args(parser): + """ + Register environment parameters. 
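+
+        Usage sketch:
+            parser = argparse.ArgumentParser()
+            FunctionEnvironment.register_args(parser)
+            params = parser.parse_args()
+            env = FunctionEnvironment(params)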
+ """ + parser.add_argument( + "--queue_strategy", + type=str, + default="uniform_sampling", + help="in [precompute_batches, uniform_sampling, uniform_sampling_replacement]", + ) + + parser.add_argument("--collate_queue_size", type=int, default=2000) + + parser.add_argument( + "--use_sympy", + type=bool_flag, + default=False, + help="Whether to use sympy parsing (basic simplification)", + ) + parser.add_argument( + "--simplify", + type=bool_flag, + default=False, + help="Whether to use further sympy simplification", + ) + parser.add_argument( + "--use_abs", + type=bool_flag, + default=False, + help="Whether to replace log and sqrt by log(abs) and sqrt(abs)", + ) + + # encoding + parser.add_argument( + "--operators_to_downsample", + type=str, + default="div_0,arcsin_0,arccos_0,tan_0.2,arctan_0.2,sqrt_5,pow2_3,inv_3", + help="Which operator to remove", + ) + parser.add_argument( + "--operators_to_not_repeat", + type=str, + default="", + help="Which operator to not repeat", + ) + + parser.add_argument( + "--max_unary_depth", + type=int, + default=6, + help="Max number of operators inside unary", + ) + + parser.add_argument( + "--required_operators", + type=str, + default="", + help="Which operator to remove", + ) + parser.add_argument( + "--extra_unary_operators", + type=str, + default="", + help="Extra unary operator to add to data generation", + ) + parser.add_argument( + "--extra_binary_operators", + type=str, + default="", + help="Extra binary operator to add to data generation", + ) + parser.add_argument( + "--extra_constants", + type=str, + default=None, + help="Additional int constants floats instead of ints", + ) + + parser.add_argument("--min_input_dimension", type=int, default=1) + parser.add_argument("--max_input_dimension", type=int, default=10) + parser.add_argument("--min_output_dimension", type=int, default=1) + parser.add_argument("--max_output_dimension", type=int, default=1) + parser.add_argument( + "--enforce_dim", + type=bool, + default=True, + help="should we enforce that we get as many examples of each dim ?" + ) + + parser.add_argument( + "--use_controller", + type=bool, + default=True, + help="should we enforce that we get as many examples of each dim ?" 
+        )
+
+
+        parser.add_argument(
+            "--float_precision",
+            type=int,
+            default=3,
+            help="Number of digits in the mantissa",
+        )
+        parser.add_argument(
+            "--mantissa_len",
+            type=int,
+            default=1,
+            help="Number of tokens for the mantissa (must be a divisor of float_precision+1)",
+        )
+        parser.add_argument(
+            "--max_exponent", type=int, default=100, help="Maximal order of magnitude"
+        )
+        parser.add_argument(
+            "--max_exponent_prefactor", type=int, default=1, help="Maximal order of magnitude in prefactors"
+        )
+        parser.add_argument(
+            "--max_token_len",
+            type=int,
+            default=0,
+            help="max size of tokenized sentences, 0 is no filtering",
+        )
+        parser.add_argument(
+            "--tokens_per_batch",
+            type=int,
+            default=10000,
+            help="max number of tokens per batch",
+        )
+        parser.add_argument(
+            "--pad_to_max_dim",
+            type=bool_flag,
+            default=True,
+            help="should we pad inputs to the maximum dimension?",
+        )
+
+        # generator
+        parser.add_argument(
+            "--max_int",
+            type=int,
+            default=10,
+            help="Maximal integer in symbolic expressions",
+        )
+        parser.add_argument(
+            "--min_binary_ops_per_dim",
+            type=int,
+            default=0,
+            help="Min number of binary operators per input dimension",
+        )
+        parser.add_argument(
+            "--max_binary_ops_per_dim",
+            type=int,
+            default=1,
+            help="Max number of binary operators per input dimension",
+        )
+        parser.add_argument(
+            "--max_binary_ops_offset",
+            type=int,
+            default=4,
+            help="Offset for max number of binary operators",
+        )
+        parser.add_argument(
+            "--min_unary_ops",
+            type=int,
+            default=0,
+            help="Min number of unary operators"
+        )
+        parser.add_argument(
+            "--max_unary_ops",
+            type=int,
+            default=4,
+            help="Max number of unary operators",
+        )
+        parser.add_argument(
+            "--min_op_prob",
+            type=float,
+            default=0.01,
+            help="Minimum probability of generating an example with given n_op, for our curriculum strategy",
+        )
+        parser.add_argument(
+            "--max_len", type=int, default=200, help="Max number of input points"
+        )
+        parser.add_argument(
+            "--min_len_per_dim", type=int, default=5, help="Min number of input points per input dimension"
+        )
+        parser.add_argument(
+            "--max_centroids", type=int, default=10, help="Max number of centroids for the input distribution"
+        )
+
+        parser.add_argument(
+            "--prob_const",
+            type=float,
+            default=0.,
+            help="Probability of generating an integer constant in leaves",
+        )
+
+        parser.add_argument(
+            "--reduce_num_constants",
+            type=bool_flag,
+            default=True,
+            help="Use minimal amount of constants in eqs"
+        )
+
+        parser.add_argument(
+            "--use_skeleton",
+            type=bool_flag,
+            default=False,
+            help="should we use a skeleton rather than functions with constants",
+        )
+
+        parser.add_argument(
+            "--prob_rand",
+            type=float,
+            default=0.0,
+            help="Probability of generating the 'rand' token in leaves",
+        )
+        parser.add_argument(
+            "--max_trials",
+            type=int,
+            default=1,
+            help="How many trials we have for a given function",
+        )
+
+        # evaluation
+        parser.add_argument(
+            "--n_prediction_points",
+            type=int,
+            default=200,
+            help="number of prediction points",
+        )
+
+        parser.add_argument(
+            "--prediction_sigmas",
+            type=str,
+            default='1,2,4,8,16',
+            help="comma-separated sigma factors for the scale of prediction points"
+        )
+
+class EnvDataset(Dataset):
+    def __init__(
+        self,
+        env,
+        task,
+        train,
+        params,
+        path,
+        skip=False,
+        size=None,
+        type=None,
+        input_length_modulo=-1,
+        **args,
+    ):
+        super().__init__()
+        self.env = env
+        self.train = train
+        self.skip = skip
+        self.task = task
+        self.batch_size = params.batch_size
+        self.env_base_seed = params.env_base_seed
+        self.path = path
+        self.count = 0
+        self.remaining_data = 0
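+        # `remaining_data` counts how many (x, y) prefixes of the most recently
+        # generated expression are still to be served (see generate_sample)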
+ self.type = type + self.input_length_modulo = input_length_modulo + self.params = params + self.errors = defaultdict(int) + + if "test_env_seed" in args: + self.test_env_seed = args["test_env_seed"] + else: + self.test_env_seed = None + if "env_info" in args: + self.env_info = args["env_info"] + else: + self.env_info = None + + assert task in FunctionEnvironment.TRAINING_TASKS + assert size is None or not self.train + assert not params.batch_load or params.reload_size > 0 + # batching + self.num_workers = params.num_workers + self.batch_size = params.batch_size + + self.batch_load = params.batch_load + self.reload_size = params.reload_size + self.local_rank = params.local_rank + + self.basepos = 0 + self.nextpos = 0 + self.seekpos = 0 + + self.collate_queue: Optional[List] = [] if self.train else None + self.collate_queue_size = params.collate_queue_size + self.tokens_per_batch = params.tokens_per_batch + + # generation, or reloading from file + if path is not None: + assert os.path.isfile(path), "{} not found".format(path) + if params.batch_load and self.train: + self.load_chunk() + else: + logger.info(f"Loading data from {path} ...") + with io.open(path, mode="r", encoding="utf-8") as f: + # either reload the entire file, or the first N lines + # (for the training set) + if not train: + lines = [] + for i, line in enumerate(f): + lines.append(json.loads(line.rstrip())) + else: + lines = [] + for i, line in enumerate(f): + if i == params.reload_size: + break + if i % params.n_gpu_per_node == params.local_rank: + # lines.append(line.rstrip()) + lines.append(json.loads(line.rstrip())) + # self.data = [xy.split("=") for xy in lines] + # self.data = [xy for xy in self.data if len(xy) == 3] + self.data = lines + logger.info(f"Loaded {len(self.data)} equations from the disk.") + + # dataset size: infinite iterator for train, finite for valid / test + # (default of 10000 if no file provided) + if self.train: + self.size = 1 << 60 + elif size is None: + self.size = 10000 if path is None else len(self.data) + else: + assert size > 0 + self.size = size + + def collate_size_fn(self, batch: Dict) -> int: + if len(batch) == 0: + return 0 + return len(batch) * max( + [seq["infos"]["input_sequence_length"] for seq in batch] + ) + + def load_chunk(self): + self.basepos = self.nextpos + logger.info( + f"Loading data from {self.path} ... seekpos {self.seekpos}, " + f"basepos {self.basepos}" + ) + endfile = False + with io.open(self.path, mode="r", encoding="utf-8") as f: + f.seek(self.seekpos, 0) + lines = [] + for i in range(self.reload_size): + line = f.readline() + if not line: + endfile = True + break + if i % self.params.n_gpu_per_node == self.local_rank: + lines.append(line.rstrip().split("|")) + self.seekpos = 0 if endfile else f.tell() + + self.data = [xy.split("\t") for _, xy in lines] + self.data = [xy for xy in self.data if len(xy) == 2] + self.nextpos = self.basepos + len(self.data) + logger.info( + f"Loaded {len(self.data)} equations from the disk. 
seekpos {self.seekpos}, " + f"nextpos {self.nextpos}" + ) + if len(self.data) == 0: + self.load_chunk() + + def collate_reduce_padding(self, collate_fn, key_fn, max_size=None): + if self.params.queue_strategy == None: return collate_fn + + f = self.collate_reduce_padding_uniform + def wrapper(b): + try: + return f( + collate_fn=collate_fn, + key_fn=key_fn, + max_size=max_size, + )(b) + except ZMQNotReady: + return ZMQNotReadySample() + + return wrapper + + def _fill_queue(self, n: int, key_fn): + """ + Add elements to the queue (fill it entirely if `n == -1`) + Optionally sort it (if `key_fn` is not `None`) + Compute statistics + """ + assert self.train, "Not Implemented" + assert len(self.collate_queue) <= self.collate_queue_size, "Problem with queue size" + + # number of elements to add + n = self.collate_queue_size - len(self.collate_queue) if n == -1 else n + assert n > 0, "n<=0" + + for _ in range(n): + if self.path is None: + sample = self.generate_sample() + else: + ##TODO + assert False, "need to finish implementing load dataset, but do not know how to handle read index" + sample = self.read_sample(index) + self.collate_queue.append(sample) + + # sort sequences + if key_fn is not None: + self.collate_queue.sort(key=key_fn) + + def collate_reduce_padding_uniform(self, collate_fn, key_fn, max_size=None): + """ + Stores a queue of COLLATE_QUEUE_SIZE candidates (created with warm-up). + When collating, insert into the queue then sort by key_fn. + Return a random range in collate_queue. + @param collate_fn: the final collate function to be used + @param key_fn: how elements should be sorted (input is an item) + @param size_fn: if a target batch size is wanted, function to compute the size (input is a batch) + @param max_size: if not None, overwrite params.batch.tokens + @return: a wrapped collate_fn + """ + + def wrapped_collate(sequences: List): + + if not self.train: + return collate_fn(sequences) + + # fill queue + + assert all(seq == SKIP_ITEM for seq in sequences) + assert len(self.collate_queue) < self.collate_queue_size, "Queue size too big, current queue size ({}/{})".format(len(self.collate_queue), self.collate_queue_size) + self._fill_queue(n=-1, key_fn=key_fn) + assert len(self.collate_queue) == self.collate_queue_size, "Fill has not been successful" + + # select random index + before = self.env.rng.randint(-self.batch_size, len(self.collate_queue)) + before = max(min(before, len(self.collate_queue) - self.batch_size), 0) + after = self.get_last_seq_id(before, max_size) + + # create batch / remove sampled sequences from the queue + to_ret = collate_fn(self.collate_queue[before:after]) + self.collate_queue = ( + self.collate_queue[:before] + self.collate_queue[after:] + ) + return to_ret + + return wrapped_collate + + + def get_last_seq_id(self, before: int, max_size: Optional[int]) -> int: + """ + Return the last sequence ID that would allow to fit according to `size_fn`. 
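+
+        Example (sketch): with tokens_per_batch=10000 and a length-sorted
+        queue, `after` grows from `before` until
+        len(batch) * max(input_sequence_length) would exceed 10000.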
+ """ + max_size = self.tokens_per_batch if max_size is None else max_size + + if max_size < 0: + after = before + self.batch_size + else: + after = before + while ( + after < len(self.collate_queue) + and self.collate_size_fn(self.collate_queue[before:after]) < max_size + ): + after += 1 + # if we exceed `tokens_per_batch`, remove the last element + size = self.collate_size_fn(self.collate_queue[before:after]) + if size > max_size: + if after > before + 1: + after -= 1 + else: + logger.warning( + f"Exceeding tokens_per_batch: {size} " + f"({after - before} sequences)" + ) + return after + + def collate_fn(self, elements): + """ + Collate samples into a batch. + """ + + samples = zip_dic(elements) + info_tensor = { + info_type: torch.LongTensor(samples["infos"][info_type]) + for info_type in samples["infos"].keys() + } + samples["infos"] = info_tensor + if "input_sequence_length" in samples["infos"]: + del samples["infos"]["input_sequence_length"] + errors = copy.deepcopy(self.errors) + self.errors = defaultdict(int) + return samples, errors + + def init_rng(self): + """ + Initialize random generator for training. + """ + if self.env.rng is not None: + return + if self.train: + worker_id = self.get_worker_id() + self.env.worker_id = worker_id + seed = [worker_id, self.params.global_rank, self.env_base_seed] + if self.env_info is not None: + seed += [self.env_info] + self.env.rng = np.random.RandomState(seed) + logger.info( + f"Initialized random generator for worker {worker_id}, with seed " + f"{seed} " + f"(base seed={self.env_base_seed})." + ) + else: + worker_id = self.get_worker_id() + self.env.worker_id = worker_id + seed = [worker_id, self.params.global_rank, self.test_env_seed if "valid" in self.type else 0] + self.env.rng = np.random.RandomState(seed) + logger.info( + "Initialized {} generator, with seed {} (random state: {})".format( + self.type, seed, self.env.rng + ) + ) + + def get_worker_id(self): + """ + Get worker ID. + """ + if not self.train: + return 0 + worker_info = torch.utils.data.get_worker_info() + assert (worker_info is None) == (self.num_workers == 0), "issue in worker id" + return 0 if worker_info is None else worker_info.id + + def __len__(self): + """ + Return dataset size. + """ + return self.size + + def __getitem__(self, index): + """ + Return a training sample. + Either generate it, or read it from file. + """ + self.init_rng() + if self.path is None: + if self.train and self.skip: + return SKIP_ITEM + else: + sample = self.generate_sample() + return sample + else: + if self.train and self.skip: + return SKIP_ITEM + else: + return self.read_sample(index) + + def read_sample(self, index): + """ + Read a sample. 
+ """ + idx = index + if self.train: + if self.batch_load: + if index >= self.nextpos: + self.load_chunk() + idx = index - self.basepos + else: + index = self.env.rng.randint(len(self.data)) + idx = index + + def str_list_to_float_array(lst): + for i in range(len(lst)): + for j in range(len(lst[i])): + lst[i][j]=float(lst[i][j]) + return np.array(lst) + + x = copy.deepcopy(self.data[idx]) + x["x_to_fit"]=str_list_to_float_array(x["x_to_fit"]) + x["y_to_fit"]=str_list_to_float_array(x["y_to_fit"]) + x["x_to_predict"]=str_list_to_float_array(x["x_to_predict"]) + x["y_to_predict"]=str_list_to_float_array(x["y_to_predict"]) + x["tree"] = self.env.equation_encoder.decode(x["tree"].split(",")) + x["tree_encoded"] = self.env.equation_encoder.encode(x["tree"]) + infos = {} + + for col in x.keys(): + if col not in ["x_to_fit", "y_to_fit", "x_to_predict", "y_to_predict", "tree", "tree_encoded"]: + infos[col]=int(x[col]) + x["infos"]=infos + for k in infos.keys(): del x[k] + return x + + def generate_sample(self): + """ + Generate a sample. + """ + + if self.remaining_data == 0: + self.expr, errors = self.env.gen_expr( + self.train, + input_length_modulo=self.input_length_modulo, + ) + for error, count in errors.items(): + self.errors[error] += count + + self.remaining_data = len(self.expr["X_to_fit"]) + + self.remaining_data -= 1 + x_to_fit = self.expr["X_to_fit"][-self.remaining_data] + y_to_fit = self.expr["Y_to_fit"][-self.remaining_data] + sample = copy.deepcopy(self.expr) + sample["x_to_fit"] = x_to_fit + sample["y_to_fit"] = y_to_fit + del sample["X_to_fit"] + del sample["Y_to_fit"] + sample["infos"] = select_dico_index(sample["infos"], -self.remaining_data) + sequence = [] + for n in range(sample["infos"]["n_input_points"]): + sequence.append([sample["x_to_fit"][n], sample["y_to_fit"][n]]) + sample["infos"]["input_sequence_length"] = self.env.get_length_after_batching( + [sequence] + )[0].item() + if sample["infos"]["input_sequence_length"] > self.params.tokens_per_batch: + # print(sample["infos"]["input_sequence_length"], self.params.tokens_per_batch) + return self.generate_sample() + self.count += 1 + return sample + +def select_dico_index(dico, idx): + new_dico = {} + for k in dico.keys(): + new_dico[k] = dico[k][idx] + return new_dico diff --git a/symbolicregression/envs/generators.py b/symbolicregression/envs/generators.py new file mode 100644 index 0000000..6692e28 --- /dev/null +++ b/symbolicregression/envs/generators.py @@ -0,0 +1,776 @@ +from abc import ABC, abstractmethod +from ast import parse +from operator import length_hint, xor + +# from turtle import degrees +import numpy as np +import math +import scipy.special +import copy +from logging import getLogger +import time +from numpy.compat.py3k import npy_load_module +from sympy import Min +from symbolicregression.envs import encoders +from collections import defaultdict +from scipy.stats import special_ortho_group + +logger = getLogger() +import random + +operators_real = { + "add": 2, + "sub": 2, + "mul": 2, + "div": 2, + "abs": 1, + "inv": 1, + "sqrt": 1, + "log": 1, + "exp": 1, + "sin": 1, + "arcsin": 1, + "cos": 1, + "arccos": 1, + "tan": 1, + "arctan": 1, + "pow2":1, + "pow3":1, +} + +operators_extra = {"pow": 2} + +math_constants = ["e", "pi", "euler_gamma", "CONSTANT"] +all_operators = {**operators_real, **operators_extra} + + +class Node: + def __init__(self, value, params, children=None): + self.value = value + self.children = children if children else [] + self.params = params + + def push_child(self, child): + 
self.children.append(child) + + def prefix(self): + s = str(self.value) + for c in self.children: + s += "," + c.prefix() + return s + + # export to latex qtree format: prefix with \Tree, use package qtree + def qtree_prefix(self): + s = "[.$" + str(self.value) + "$ " + for c in self.children: + s += c.qtree_prefix() + s += "]" + return s + + def infix(self): + nb_children = len(self.children) + if nb_children ==0: + if self.value.lstrip('-').isdigit(): + return str(self.value) + else: + #try: + # s = f"%.{self.params.float_precision}e" % float(self.value) + #except ValueError: + s = str(self.value) + return s + if nb_children == 1: + s = str(self.value) + if s == "pow2": + s = "(" + self.children[0].infix() + ")**2" + elif s == "pow3": + s = "(" + self.children[0].infix() + ")**3" + else: + s = s + "(" + self.children[0].infix() + ")" + return s + s = "(" + self.children[0].infix() + for c in self.children[1:]: + s = s + " " + str(self.value) + " " + c.infix() + return s + ")" + + def __len__(self): + lenc = 1 + for c in self.children: + lenc += len(c) + return lenc + + def __str__(self): + # infix a default print + return self.infix() + + def __repr__(self): + # infix a default print + return str(self) + + def val(self, x, deterministic=True): + if len(self.children) == 0: + if str(self.value).startswith("x_"): + _, dim = self.value.split("_") + dim = int(dim) + return x[:, dim] + elif str(self.value) == "rand": + if deterministic: + return np.zeros((x.shape[0],)) + return np.random.randn(x.shape[0]) + elif str(self.value) in math_constants: + return getattr(np, str(self.value))*np.ones((x.shape[0],)) + else: + return float(self.value)*np.ones((x.shape[0],)) + + if self.value == "add": + return self.children[0].val(x) + self.children[1].val(x) + if self.value == "sub": + return self.children[0].val(x) - self.children[1].val(x) + if self.value == "mul": + m1, m2 = self.children[0].val(x), self.children[1].val(x) + try: + return m1 * m2 + except Exception as e: + #print(e) + nans = np.empty((m1.shape[0],)) + nans[:]=np.nan + return nans + if self.value == "pow": + m1, m2 = self.children[0].val(x), self.children[1].val(x) + try: + return np.power(m1, m2) + except Exception as e: + #print(e) + nans = np.empty((m1.shape[0],)) + nans[:] = np.nan + return nans + if self.value == "max": + return np.maximum(self.children[0].val(x), self.children[1].val(x)) + if self.value == "min": + return np.minimum(self.children[0].val(x), self.children[1].val(x)) + + if self.value == "div": + denominator = self.children[1].val(x) + denominator[denominator == 0.] = np.nan + try: + return self.children[0].val(x) / denominator + except Exception as e: + #print(e) + nans = np.empty((denominator.shape[0],)) + nans[:]=np.nan + return nans + if self.value == "inv": + denominator = self.children[0].val(x) + denominator[denominator == 0.] = np.nan + try: + return 1 / denominator + except Exception as e: + #print(e) + nans = np.empty((denominator.shape[0],)) + nans[:] = np.nan + return nans + if self.value == "log": + numerator = self.children[0].val(x) + if self.params.use_abs: + numerator[numerator<=0.] *= -1 + else: + numerator[numerator<=0.]=np.nan + try: + return np.log(numerator) + except Exception as e: + #print(e) + nans = np.empty((numerator.shape[0],)) + nans[:] = np.nan + return nans + + if self.value == "sqrt": + numerator = self.children[0].val(x) + if self.params.use_abs: + numerator[numerator<=0.] *= -1 + else: + numerator[numerator<0.] 
= np.nan + try: + return np.sqrt(numerator) + except Exception as e: + #print(e) + nans = np.empty((numerator.shape[0],)) + nans[:] = np.nan + return nans + if self.value == "pow2": + numerator = self.children[0].val(x) + try: + return numerator**2 + except Exception as e: + #print(e) + nans = np.empty((numerator.shape[0],)) + nans[:] = np.nan + return nans + if self.value == "pow3": + numerator = self.children[0].val(x) + try: + return numerator**3 + except Exception as e: + #print(e) + nans = np.empty((numerator.shape[0],)) + nans[:] = np.nan + return nans + if self.value == "abs": + return np.abs(self.children[0].val(x)) + if self.value == "sign": + return (self.children[0].val(x) >= 0) * 2. - 1. + if self.value == "step": + x = self.children[0].val(x) + return x if x > 0 else 0 + if self.value == "id": + return self.children[0].val(x) + if self.value == "fresnel": + return scipy.special.fresnel(self.children[0].val(x))[0] + if self.value.startswith("eval"): + n = self.value[-1] + return getattr(scipy.special, self.value[:-1])(n, self.children[0].val(x))[ + 0 + ] + else: + fn = getattr(np, self.value, None) + if fn is not None: + try: + return fn(self.children[0].val(x)) + except Exception as e: + nans = np.empty((x.shape[0],)) + nans[:] = np.nan + return nans + fn = getattr(scipy.special, self.value, None) + if fn is not None: + return fn(self.children[0].val(x)) + assert False, "Could not find function" + + def get_recurrence_degree(self): + recurrence_degree = 0 + if len(self.children) == 0: + if str(self.value).startswith("x_"): + _, _, offset = self.value.split("_") + offset = int(offset) + if offset > recurrence_degree: + recurrence_degree = offset + return recurrence_degree + return max([child.get_recurrence_degree() for child in self.children]) + + def replace_node_value(self, old_value, new_value): + if self.value == old_value: + self.value = new_value + for child in self.children: + child.replace_node_value(old_value, new_value) + +class NodeList: + def __init__(self, nodes): + self.nodes = [] + for node in nodes: + self.nodes.append(node) + self.params = nodes[0].params + + def infix(self): + return " | ".join([node.infix() for node in self.nodes]) + + def __len__(self): + return sum([len(node) for node in self.nodes]) + + def prefix(self): + return ",|,".join([node.prefix() for node in self.nodes]) + + def __str__(self): + return self.infix() + + def __repr__(self): + return str(self) + + def val(self, xs, deterministic=True): + batch_vals = [np.expand_dims(node.val(np.copy(xs), deterministic=deterministic), -1) for node in self.nodes] + return np.concatenate(batch_vals, -1) + + def replace_node_value(self, old_value, new_value): + for node in self.nodes: + node.replace_node_value(old_value, new_value) + +class Generator(ABC): + def __init__(self, params): + pass + + @abstractmethod + def generate_datapoints(self, rng): + pass + + +class RandomFunctions(Generator): + def __init__(self, params, special_words): + super().__init__(params) + self.params = params + self.prob_const = params.prob_const + self.prob_rand = params.prob_rand + self.max_int = params.max_int + self.min_binary_ops_per_dim = params.min_binary_ops_per_dim + self.max_binary_ops_per_dim = params.max_binary_ops_per_dim + self.min_unary_ops = params.min_unary_ops + self.max_unary_ops = params.max_unary_ops + self.min_output_dimension = params.min_output_dimension + self.min_input_dimension = params.min_input_dimension + self.max_input_dimension = params.max_input_dimension + self.max_output_dimension = 
params.max_output_dimension + self.max_number = 10 ** (params.max_exponent + params.float_precision) + self.operators = copy.deepcopy(operators_real) + + self.operators_dowsample_ratio = defaultdict(float) + if params.operators_to_downsample != "": + for operator in self.params.operators_to_downsample.split(","): + operator, ratio = operator.split("_") + ratio = float(ratio) + self.operators_dowsample_ratio[operator]=ratio + + if params.required_operators != "": + self.required_operators = self.params.required_operators.split(",") + else: + self.required_operators = [] + + if params.extra_binary_operators != "": self.extra_binary_operators = self.params.extra_binary_operators.split(",") + else: self.extra_binary_operators = [] + if params.extra_unary_operators != "": self.extra_unary_operators = self.params.extra_unary_operators.split(",") + else: self.extra_unary_operators = [] + + self.unaries = [ + o for o in self.operators.keys() if np.abs(self.operators[o]) == 1 + ] + self.extra_unary_operators + + self.binaries = [ + o for o in self.operators.keys() if np.abs(self.operators[o]) == 2 + ] + self.extra_binary_operators + + unaries_probabilities = [] + for op in self.unaries: + if op not in self.operators_dowsample_ratio: + unaries_probabilities.append(1.0) + else: + ratio = self.operators_dowsample_ratio[op] + unaries_probabilities.append(ratio) + self.unaries_probabilities = np.array(unaries_probabilities) + self.unaries_probabilities /= self.unaries_probabilities.sum() + + + binaries_probabilities = [] + for op in self.binaries: + if op not in self.operators_dowsample_ratio: + binaries_probabilities.append(1.0) + else: + ratio = self.operators_dowsample_ratio[op] + binaries_probabilities.append(ratio) + self.binaries_probabilities = np.array(binaries_probabilities) + self.binaries_probabilities /= self.binaries_probabilities.sum() + + self.unary = False#len(self.unaries) > 0 + self.distrib = self.generate_dist(2 * self.max_binary_ops_per_dim * self.max_input_dimension) + + self.constants = [ + str(i) for i in range(-self.max_int, self.max_int + 1) if i != 0 + ] + self.constants += math_constants + self.variables = ["rand"] + [f"x_{i}" for i in range(self.max_input_dimension)] + self.symbols = ( + list(self.operators) + + self.constants + + self.variables + + ["|", "INT+", "INT-", "FLOAT+", "FLOAT-", "pow", "0"] + ) + self.constants.remove("CONSTANT") + + if self.params.extra_constants is not None: + self.extra_constants = self.params.extra_constants.split(",") + else: + self.extra_constants = [] + + self.general_encoder = encoders.GeneralEncoder(params, self.symbols, all_operators) + self.float_encoder = self.general_encoder.float_encoder + self.float_words = special_words + sorted(list(set(self.float_encoder.symbols))) + self.equation_encoder = self.general_encoder.equation_encoder + self.equation_words = sorted(list(set(self.symbols))) + self.equation_words = special_words + self.equation_words + + def generate_dist(self, max_ops): + """ + `max_ops`: maximum number of operators + Enumerate the number of possible unary-binary trees that can be generated from empty nodes. 
+        D[n][e] represents the number of different binary trees with n nodes that
+        can be generated from e empty nodes, using the following recursion:
+            D(n, 0) = 0
+            D(0, e) = 1
+            D(n, e) = D(n, e - 1) + p1 * D(n - 1, e) + D(n - 1, e + 1)
+        where p1 = 0 for binary-only trees and 1 for unary-binary trees.
+        """
+        p1 = 1 if self.unary else 0
+        # enumerate possible trees
+        D = []
+        D.append([0] + ([1 for i in range(1, 2 * max_ops + 1)]))
+        for n in range(1, 2 * max_ops + 1):  # number of operators
+            s = [0]
+            for e in range(1, 2 * max_ops - n + 1):  # number of empty nodes
+                s.append(s[e - 1] + p1 * D[n - 1][e] + D[n - 1][e + 1])
+            D.append(s)
+        assert all(len(D[i]) >= len(D[i + 1]) for i in range(len(D) - 1)), "issue in generate_dist"
+        return D
+
+    def generate_float(self, rng, exponent=None):
+        sign = rng.choice([-1, 1])
+        mantissa = float(rng.choice(range(1, 10 ** self.params.float_precision)))
+        min_power = -self.params.max_exponent_prefactor - (self.params.float_precision + 1) // 2
+        max_power = self.params.max_exponent_prefactor - (self.params.float_precision + 1) // 2
+        if exponent is None:
+            exponent = rng.randint(min_power, max_power + 1)
+        constant = sign * (mantissa * 10 ** exponent)
+        return str(constant)
+
+    def generate_int(self, rng):
+        return str(rng.choice(self.constants + self.extra_constants))
+
+    def generate_leaf(self, rng, input_dimension):
+        if rng.rand() < self.prob_rand:
+            return "rand"
+        else:
+            if self.n_used_dims < input_dimension:
+                dimension = self.n_used_dims
+                self.n_used_dims += 1
+                return f"x_{dimension}"
+            else:
+                draw = rng.rand()
+                if draw < self.prob_const:
+                    return self.generate_int(rng)
+                else:
+                    dimension = rng.randint(0, input_dimension)
+                    return f"x_{dimension}"
+
+    def generate_ops(self, rng, arity):
+        if arity == 1:
+            ops = self.unaries
+            probas = self.unaries_probabilities
+        else:
+            ops = self.binaries
+            probas = self.binaries_probabilities
+        return rng.choice(ops, p=probas)
+
+    def sample_next_pos(self, rng, nb_empty, nb_ops):
+        """
+        Sample the position of the next node (binary case).
+        Sample a position in {0, ..., `nb_empty` - 1}.
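+
+        Returns a pair (position, arity). Sketch: when self.unary is False,
+        arity is always 2 and position i is drawn with probability
+        proportional to D[nb_ops - 1][nb_empty - i + 1].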
+ """ + assert nb_empty > 0 + assert nb_ops > 0 + probs = [] + if self.unary: + for i in range(nb_empty): + probs.append(self.distrib[nb_ops - 1][nb_empty - i]) + for i in range(nb_empty): + probs.append(self.distrib[nb_ops - 1][nb_empty - i + 1]) + probs = [p / self.distrib[nb_ops][nb_empty] for p in probs] + probs = np.array(probs, dtype=np.float64) + e = rng.choice(len(probs), p=probs) + arity = 1 if self.unary and e < nb_empty else 2 + e %= nb_empty + return e, arity + + def generate_tree(self, rng, nb_binary_ops, input_dimension): + self.n_used_dims = 0 + tree = Node(0, self.params) + empty_nodes = [tree] + next_en = 0 + nb_empty = 1 + while nb_binary_ops > 0 : + next_pos, arity = self.sample_next_pos(rng, nb_empty, nb_binary_ops) + next_en += next_pos + op = self.generate_ops(rng, arity) + empty_nodes[next_en].value = op + for _ in range(arity): + e = Node(0, self.params) + empty_nodes[next_en].push_child(e) + empty_nodes.append(e) + next_en += 1 + nb_empty += arity - 1 - next_pos + nb_binary_ops -= 1 + rng.shuffle(empty_nodes) + for n in empty_nodes: + if len(n.children)==0: + n.value = self.generate_leaf(rng, input_dimension) + return tree + + def generate_multi_dimensional_tree(self, rng, input_dimension=None, output_dimension=None, nb_unary_ops=None, nb_binary_ops=None): + trees = [] + + if input_dimension is None: + input_dimension = rng.randint( + self.min_input_dimension, self.max_input_dimension + 1) + if output_dimension is None: + output_dimension = rng.randint( + self.min_output_dimension, self.max_output_dimension + 1 + ) + if nb_binary_ops is None: + min_binary_ops = self.min_binary_ops_per_dim * input_dimension + max_binary_ops = self.max_binary_ops_per_dim * input_dimension + nb_binary_ops_to_use = [rng.randint(min_binary_ops, self.params.max_binary_ops_offset+max_binary_ops) for dim in range(output_dimension)] + elif isinstance(nb_binary_ops, int): + nb_binary_ops_to_use = [nb_binary_ops for _ in range(output_dimension)] + else: + nb_binary_ops_to_use = nb_binary_ops + if nb_unary_ops is None: + nb_unary_ops_to_use = [rng.randint(self.min_unary_ops, self.max_unary_ops+1) for dim in range(output_dimension)] + elif isinstance(nb_unary_ops, int): + nb_unary_ops_to_use = [nb_unary_ops for _ in range(output_dimension)] + else: + nb_unary_ops_to_use = nb_unary_ops + + for i in range(output_dimension): + tree = self.generate_tree(rng, nb_binary_ops_to_use[i], input_dimension) + tree = self.add_unaries(rng, tree, nb_unary_ops_to_use[i]) + ##Adding constants + if self.params.reduce_num_constants: + tree = self.add_prefactors(rng, tree) + else: + tree = self.add_linear_transformations(rng, tree, target=self.variables) + tree = self.add_linear_transformations(rng, tree, target=self.unaries) + trees.append(tree) + tree = NodeList(trees) + + nb_unary_ops_to_use = [len([x for x in tree_i.prefix().split(",") if x in self.unaries]) for tree_i in tree.nodes] + nb_binary_ops_to_use = [len([x for x in tree_i.prefix().split(",") if x in self.binaries]) for tree_i in tree.nodes] + + for op in self.required_operators: + if op not in tree.infix(): + return self.generate_multi_dimensional_tree(rng, input_dimension, output_dimension, nb_unary_ops, nb_binary_ops) + + return tree, input_dimension, output_dimension, nb_unary_ops_to_use, nb_binary_ops_to_use + + def add_unaries(self, rng, tree, nb_unaries): + prefix = self._add_unaries(rng,tree) + prefix = prefix.split(',') + indices = [] + for i, x in enumerate(prefix): + if x in self.unaries: + indices.append(i) + rng.shuffle(indices) + if 
len(indices) > nb_unaries:
+            to_remove = indices[: len(indices) - nb_unaries]
+            for index in sorted(to_remove, reverse=True):
+                del prefix[index]
+        tree = self.equation_encoder.decode(prefix).nodes[0]
+        return tree
+
+    def _add_unaries(self, rng, tree):
+
+        s = str(tree.value)
+
+        for c in tree.children:
+            if len(c.prefix().split(",")) < self.params.max_unary_depth:
+                begin = rng.choice(self.unaries, p=self.unaries_probabilities)
+                s += f",{begin}," + self._add_unaries(rng, c)
+            else:
+                s += f"," + self._add_unaries(rng, c)
+        return s
+
+    # NB: the bodies of the helpers below follow the reference implementation
+    # in facebookresearch/symbolicregression and may differ in detail from
+    # this commit.
+    def add_prefactors(self, rng, tree):
+        transformed_prefix = self._add_prefactors(rng, tree)
+        if transformed_prefix == tree.prefix():
+            a = self.generate_float(rng)
+            transformed_prefix = f"mul,{a}," + transformed_prefix
+        a = self.generate_float(rng)
+        transformed_prefix = f"add,{a}," + transformed_prefix
+        tree = self.equation_encoder.decode(transformed_prefix.split(",")).nodes[0]
+        return tree
+
+    def _add_prefactors(self, rng, tree):
+
+        s = str(tree.value)
+        a, b = self.generate_float(rng), self.generate_float(rng)
+        if s in ["add", "sub"]:
+            s += (
+                "," if tree.children[0].value in ["add", "sub"] else f",mul,{a},"
+            ) + self._add_prefactors(rng, tree.children[0])
+            s += (
+                "," if tree.children[1].value in ["add", "sub"] else f",mul,{b},"
+            ) + self._add_prefactors(rng, tree.children[1])
+        elif s in self.unaries and tree.children[0].value not in ["add", "sub"]:
+            s += f",add,{a},mul,{b}," + self._add_prefactors(rng, tree.children[0])
+        else:
+            for c in tree.children:
+                s += f"," + self._add_prefactors(rng, c)
+        return s
+
+    def add_linear_transformations(self, rng, tree, target, add_after=False):
+        prefix = tree.prefix().split(",")
+        indices = []
+        for i, x in enumerate(prefix):
+            if x in target:
+                indices.append(i)
+        offset = 0
+        for idx in indices:
+            a, b = self.generate_float(rng), self.generate_float(rng)
+            if add_after:
+                prefix = prefix[: idx + offset + 1] + ["add", a, "mul", b] + prefix[idx + offset + 1 :]
+            else:
+                prefix = prefix[: idx + offset] + ["add", a, "mul", b] + prefix[idx + offset :]
+            offset += 4
+        tree = self.equation_encoder.decode(prefix).nodes[0]
+        return tree
+
+    def relabel_variables(self, tree):
+        active_variables = []
+        for elem in tree.prefix().split(","):
+            if elem.startswith("x_"):
+                active_variables.append(elem)
+        active_variables = list(set(active_variables))
+        input_dimension = len(active_variables)
+        if input_dimension == 0:
+            return 0
+        active_variables.sort(key=lambda x: int(x[2:]))
+        for j, xi in enumerate(active_variables):
+            tree.replace_node_value(xi, "x_{}".format(j))
+        return input_dimension
+
+    def function_to_skeleton(self, tree, skeletonize_integers=False, constants_with_idx=False):
+        constants = []
+        prefix = tree.prefix().split(",")
+        j = 0
+        for i, pre in enumerate(prefix):
+            try:
+                float(pre)
+                is_float = True
+                if pre.lstrip("-").isdigit():
+                    is_float = False
+            except ValueError:
+                is_float = False
+
+            if pre.startswith("CONSTANT"):
+                constants.append("CONSTANT")
+                if constants_with_idx:
+                    prefix[i] = "CONSTANT_{}".format(j)
+                j += 1
+            elif is_float or (pre in self.constants and skeletonize_integers):
+                prefix[i] = "CONSTANT_{}".format(j) if constants_with_idx else "CONSTANT"
+                while i > 0 and prefix[i - 1] in self.unaries:
+                    del prefix[i - 1]
+                try:
+                    value = float(pre)
+                except ValueError:
+                    value = getattr(np, pre)
+                constants.append(value)
+                j += 1
+            else:
+                continue
+
+        new_tree = self.equation_encoder.decode(prefix)
+        return new_tree, constants
+
+    def wrap_equation_floats(self, tree, constants):
+        prefix = tree.prefix().split(",")
+        j = 0
+        for i, elem in enumerate(prefix):
+            if elem.startswith("CONSTANT"):
+                prefix[i] = str(constants[j])
+                j += 1
+        assert j == len(constants), "all constants were not fitted"
+        assert "CONSTANT" not in prefix, "tree {} got constant after wrapper {}".format(tree, constants)
+        tree_with_constants = self.equation_encoder.decode(prefix)
+        return tree_with_constants
+
+    def order_datapoints(self, inputs, outputs):
+        mean_input = inputs.mean(0)
+        distance_to_mean = np.linalg.norm(inputs - mean_input, axis=-1)
+        order_by_distance = np.argsort(distance_to_mean)
+        return inputs[order_by_distance], outputs[order_by_distance]
+
+    def _generate_datapoints(self, tree, n_points, scale, rng, input_dimension, input_distribution_type, n_centroids, max_trials, rotate=True, offset=None):
+        inputs, outputs = [], []
+        remaining_points = n_points
+        trials = 0
+
+        means = rng.randn(n_centroids, input_dimension)
+        covariances = rng.uniform(0, 1, size=(n_centroids, input_dimension))
+        if rotate:
+            rotations = [special_ortho_group.rvs(input_dimension) if input_dimension > 1 else np.identity(1) for i in range(n_centroids)]
+        else:
+            rotations = [np.identity(input_dimension) for i in range(n_centroids)]
+
+        weights = rng.uniform(0, 1, size=(n_centroids,))
+        weights /= np.sum(weights)
+        n_points_comp = rng.multinomial(n_points, weights)
+
+        while remaining_points > 0 and trials < max_trials:
+            # sample inputs from the mixture of (rotated) centroids
+            samples = []
+            for (mean, covariance, rotation, n) in zip(means, covariances, rotations, n_points_comp):
+                if input_distribution_type == "gaussian":
+                    sample = rng.multivariate_normal(mean, np.diag(covariance), int(n))
+                else:  # "uniform"
+                    sample = mean + rng.uniform(-1, 1, size=(int(n), input_dimension)) * np.sqrt(covariance)
+                samples.append(sample @ rotation)
+            input = scale * np.concatenate(samples, 0)
+            output = tree.val(input)
+            output[np.abs(output) >= self.max_number] = np.nan
+            output[np.abs(output) == np.inf] = np.nan
+            is_nan_idx = np.any(np.isnan(output), -1)
+            input = input[~is_nan_idx, :]
+            output = output[~is_nan_idx, :]
+
+            valid_points = output.shape[0]
+            trials += 1
+            remaining_points -= valid_points
+            if valid_points == 0:
+                continue
+            inputs.append(input)
+            outputs.append(output)
+
+        if remaining_points > 0:
+            return None, None
+
+        inputs = np.concatenate(inputs, 0)[:n_points]
+        outputs = np.concatenate(outputs, 0)[:n_points]
+        return inputs, outputs
+
+    def generate_datapoints(self, tree, n_input_points, n_prediction_points, prediction_sigmas, rotate=True, offset=None, **kwargs):
+        inputs, outputs = self._generate_datapoints(tree=tree, n_points=n_input_points, scale=1, rotate=rotate, offset=offset, **kwargs)
+
+        if inputs is None:
+            return None, None
+        datapoints = {"fit": (inputs, outputs)}
+
+        if n_prediction_points == 0:
+            return tree, datapoints
+
+        for sigma_factor in prediction_sigmas:
+            inputs, outputs = self._generate_datapoints(tree=tree, n_points=n_prediction_points, scale=sigma_factor, rotate=rotate, offset=offset, **kwargs)
+            if inputs is None:
+                return None, None
+            datapoints["predict_{}".format(sigma_factor)] = (inputs, outputs)
+
+        return tree, datapoints
+
+
+if __name__ == "__main__":
+
+    from parsers import get_parser
+    from symbolicregression.envs.environment import SPECIAL_WORDS
+
+    parser = get_parser()
+    params = parser.parse_args()
RandomFunctions(params, SPECIAL_WORDS) + rng = np.random.RandomState(0) + tree, _, _, _, _ = generator.generate_multi_dimensional_tree(np.random.RandomState(0), input_dimension=1) + print(tree) + x, y = generator.generate_datapoints(rng, tree, "gaussian", 10, 200, 200) + generator.order_datapoints(x,y) diff --git a/symbolicregression/envs/simplifiers.py b/symbolicregression/envs/simplifiers.py new file mode 100644 index 0000000..6ccbea7 --- /dev/null +++ b/symbolicregression/envs/simplifiers.py @@ -0,0 +1,356 @@ +import traceback +import sympy as sp +from sympy.parsing.sympy_parser import parse_expr +from .generators import all_operators, math_constants, Node, NodeList +from sympy.core.rules import Transform +import numpy as np +from functools import partial +import numexpr as ne +import sympytorch +import torch +from ..utils import timeout, MyTimeoutError + +def simplify(f, seconds): + """ + Simplify an expression. + """ + assert seconds > 0 + @timeout(seconds) + def _simplify(f): + try: + f2 = sp.simplify(f) + if any(s.is_Dummy for s in f2.free_symbols): + return f + else: + return f2 + except MyTimeoutError: + return f + except Exception as e: + return f + return _simplify(f) + +class InvalidPrefixExpression(BaseException): + pass + +import signal +from contextlib import contextmanager + +@contextmanager +def timeout(time): + # Register a function to raise a TimeoutError on the signal. + signal.signal(signal.SIGALRM, raise_timeout) + # Schedule the signal to be sent after ``time``. + signal.alarm(time) + + try: + yield + except TimeoutError: + pass + finally: + # Unregister the signal so it won't be triggered + # if the timeout is not reached. + signal.signal(signal.SIGALRM, signal.SIG_IGN) + +def raise_timeout(signum, frame): + raise TimeoutError + +class Simplifier: + def __init__(self, generator): + + self.params = generator.params + self.encoder = generator.equation_encoder + self.operators = generator.operators + self.max_int = generator.max_int + self.local_dict = { + "n": sp.Symbol("n", real=True, nonzero=True, positive=True, integer=True), + "e": sp.E, + "pi": sp.pi, + "euler_gamma": sp.EulerGamma, + "arcsin": sp.asin, + "arccos": sp.acos, + "arctan": sp.atan, + "step": sp.Heaviside, + "sign": sp.sign, + } + for k in generator.variables: + self.local_dict[k] = sp.Symbol(k, real=True, integer=False) + + def expand_expr(self, expr): + with timeout(1): + expr = sp.expand(expr) + return expr + + def simplify_expr(self, expr): + with timeout(1): + expr = sp.simplify(expr) + return expr + + def tree_to_sympy_expr(self, tree): + prefix = tree.prefix().split(",") + sympy_compatible_infix = self.prefix_to_sympy_compatible_infix(prefix) + expr = parse_expr(sympy_compatible_infix, evaluate=True, local_dict=self.local_dict) + return expr + + def tree_to_torch_module(self, tree, dtype=torch.float32): + expr = self.tree_to_sympy_expr(tree) + mod = self.expr_to_torch_module(expr, dtype) + return mod + + def expr_to_torch_module(self, expr, dtype): + mod = sympytorch.SymPyModule(expressions=[expr]) + mod.to(dtype) + def wrapper_fn(_mod, x, constants=None): + local_dict = {} + for d in range(x.shape[1]): + local_dict["x_{}".format(d)]=x[:, d] + if constants is not None: + for d in range(constants.shape[0]): + local_dict["CONSTANT_{}".format(d)]=constants[d] + return _mod(**local_dict) + return partial(wrapper_fn, mod) + + def expr_to_numpy_fn(self, expr): + + def wrapper_fn(_expr, x, extra_local_dict={}): + local_dict = {} + for d in range(x.shape[1]): + local_dict["x_{}".format(d)]=x[:, d] + 
local_dict.update(extra_local_dict) + variables_symbols = sp.symbols(' '.join(["x_{}".format(d) for d in range(x.shape[1])])) + extra_symbols = list(extra_local_dict.keys()) + if len(extra_symbols)>0: + extra_symbols = sp.symbols(' '.join(extra_symbols)) + else: + extra_symbols=() + np_fn = sp.lambdify((*variables_symbols, *extra_symbols), _expr, modules='numpy') + return np_fn(**local_dict) + + return partial(wrapper_fn, expr) + + def tree_to_numpy_fn(self, tree): + expr = self.tree_to_sympy_expr(tree) + return self.expr_to_numpy_fn(expr) + + def tree_to_numexpr_fn(self, tree): + infix = tree.infix() + numexpr_equivalence = { + "add": "+", + "sub": "-", + "mul": "*", + "pow": "**", + "inv": "1/", + } + + for old, new in numexpr_equivalence.items(): + infix=infix.replace(old, new) + + def get_vals(dim, val): + vals_ar = np.empty((dim,)) + vals_ar[:] = val + return vals_ar + + def wrapped_numexpr_fn(_infix, x, extra_local_dict={}): + assert isinstance(x, np.ndarray) and len(x.shape)==2 + local_dict = {} + for d in range(self.params.max_input_dimension): + if "x_{}".format(d) in _infix: + if d >= x.shape[1]: + local_dict["x_{}".format(d)]=np.zeros(x.shape[0]) + else: + local_dict["x_{}".format(d)]=x[:,d] + local_dict.update(extra_local_dict) + try: + vals = ne.evaluate(_infix, local_dict=local_dict) + if len(vals.shape)==0: + vals = get_vals(x.shape[0], vals) + except Exception as e: + print(e) + print("problem with tree", _infix) + traceback.format_exc() + vals = get_vals(x.shape[0], np.nan) + return vals[:, None] + return partial(wrapped_numexpr_fn, infix) + + def sympy_expr_to_tree(self, expr): + prefix = self.sympy_to_prefix(expr) + return self.encoder.decode(prefix) + + def round_expr(self, expr, decimals=4): + with timeout(1): + expr = expr.xreplace(Transform(lambda x: x.round(decimals), lambda x: isinstance(x, sp.Float))) + return expr + + def float_to_int_expr(self, expr): + floats = expr.atoms(sp.Float) + ints = [fl for fl in floats if int(fl)==fl] + expr = expr.xreplace(dict(zip(ints, [int(i) for i in ints]))) + return expr + + def apply_fn(self, tree, fn_stack=[]): + expr = self.tree_to_sympy_expr(tree) + for (fn, arg) in fn_stack: + expr = getattr(self, fn)(expr=expr, **arg) + new_tree = self.sympy_expr_to_tree(expr) + if new_tree is None: + new_tree = tree + return new_tree + + def write_infix(self, token, args): + """ + Infix representation. 
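+
+        A few illustrative mappings, matching the branches below:
+            write_infix("add",  ["x_0", "1"]) -> "(x_0)+(1)"
+            write_infix("pow2", ["x_0"])      -> "(x_0)**2"
+            write_infix("sin",  ["x_0"])      -> "sin(x_0)"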
+ + """ + if token == "add": + return f"({args[0]})+({args[1]})" + elif token == "sub": + return f"({args[0]})-({args[1]})" + elif token == "mul": + return f"({args[0]})*({args[1]})" + elif token == "div": + return f"({args[0]})/({args[1]})" + if token == "pow": + return f"({args[0]})**({args[1]})" + elif token == "idiv": + return f"idiv({args[0]},{args[1]})" + elif token == "mod": + return f"({args[0]})%({args[1]})" + elif token == "abs": + return f"Abs({args[0]})" + elif token == "inv": + return f"1/({args[0]})" + elif token == "pow2": + return f"({args[0]})**2" + elif token == "pow3": + return f"({args[0]})**3" + elif token in all_operators: + return f"{token}({args[0]})" + else: + return token + raise InvalidPrefixExpression( + f"Unknown token in prefix expression: {token}, with arguments {args}" + ) + + def _prefix_to_sympy_compatible_infix(self, expr): + """ + Parse an expression in prefix mode, and output it in either: + - infix mode (returns human readable string) + - develop mode (returns a dictionary with the simplified expression) + """ + if len(expr) == 0: + raise InvalidPrefixExpression("Empty prefix list.") + t = expr[0] + if t in all_operators: + args = [] + l1 = expr[1:] + for _ in range(all_operators[t]): + i1, l1 = self._prefix_to_sympy_compatible_infix(l1) + args.append(i1) + return self.write_infix(t, args), l1 + else: # leaf + try: + float(t) + t = str(t) + except ValueError: + t=t + return t, expr[1:] + + def prefix_to_sympy_compatible_infix(self, expr): + """ + Convert prefix expressions to a format that SymPy can parse. + """ + p, r = self._prefix_to_sympy_compatible_infix(expr) + if len(r) > 0: + raise InvalidPrefixExpression( + f'Incorrect prefix expression "{expr}". "{r}" was not parsed.' + ) + return f"({p})" + + def _sympy_to_prefix(self, op, expr): + """ + Parse a SymPy expression given an initial root operator. + """ + n_args = len(expr.args) + + # assert (op == 'add' or op == 'mul') and (n_args >= 2) or (op != 'add' and op != 'mul') and (1 <= n_args <= 2) + + # square root + # if op == 'pow': + # if isinstance(expr.args[1], sp.Rational) and expr.args[1].p == 1 and expr.args[1].q == 2: + # return ['sqrt'] + self.sympy_to_prefix(expr.args[0]) + # elif str(expr.args[1])=='2': + # return ['sqr'] + self.sympy_to_prefix(expr.args[0]) + # elif str(expr.args[1])=='-1': + # return ['inv'] + self.sympy_to_prefix(expr.args[0]) + # elif str(expr.args[1])=='-2': + # return ['inv', 'sqr'] + self.sympy_to_prefix(expr.args[0]) + + # parse children + parse_list = [] + for i in range(n_args): + if i == 0 or i < n_args - 1: + parse_list.append(op) + parse_list += self.sympy_to_prefix(expr.args[i]) + + return parse_list + + def sympy_to_prefix(self, expr): + """ + Convert a SymPy expression to a prefix one. 
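+
+        For instance (with x_0 a SymPy Symbol), sp.sin(x_0) + 2 maps to
+        ["add", "2", "sin", "x_0"], and sp.Rational(3, 4) maps to
+        ["mul", "3", "pow", "4", "-1"].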
+ """ + if isinstance(expr, sp.Symbol): + return [str(expr)] + elif isinstance(expr, sp.Integer): + return [str(expr)] + elif isinstance(expr, sp.Float): + s = str(expr) + return [s] + elif isinstance(expr, sp.Rational): + return ["mul", str(expr.p), "pow", str(expr.q), "-1"] + elif expr == sp.EulerGamma: + return ["euler_gamma"] + elif expr == sp.E: + return ["e"] + elif expr == sp.pi: + return ["pi"] + + # if we want div and sub + # if isinstance(expr, sp.Mul) and len(expr.args)==2: + # if isinstance(expr.args[0], sp.Mul) and isinstance(expr.args[0].args[0], sp.Pow): return ['div']+self.sympy_to_prefix(expr.args[1])+self.sympy_to_prefix(expr.args[0].args[1]) + # if isinstance(expr.args[1], sp.Mul) and isinstance(expr.args[1].args[0], sp.Pow): return ['div']+self.sympy_to_prefix(expr.args[0])+self.sympy_to_prefix(expr.args[1].args[1]) + # if isinstance(expr, sp.Add) and len(expr.args)==2: + # if isinstance(expr.args[0], sp.Mul) and str(expr.args[0].args[0])=='-1': return ['sub']+self.sympy_to_prefix(expr.args[1])+self.sympy_to_prefix(expr.args[0].args[1]) + # if isinstance(expr.args[1], sp.Mul) and str(expr.args[1].args[0])=='-1': return ['sub']+self.sympy_to_prefix(expr.args[0])+self.sympy_to_prefix(expr.args[1].args[1]) + + # if isinstance(expr, sp.Pow) and str(expr.args[1])=='-1': + # return ['inv'] + self.sympy_to_prefix(expr.args[0]) + + # SymPy operator + for op_type, op_name in self.SYMPY_OPERATORS.items(): + if isinstance(expr, op_type): + return self._sympy_to_prefix(op_name, expr) + + # Unknown operator + return self._sympy_to_prefix(str(type(expr)), expr) + + SYMPY_OPERATORS = { + # Elementary functions + sp.Add: "add", + sp.Mul: "mul", + sp.Mod: "mod", + sp.Pow: "pow", + # Misc + sp.Abs: "abs", + sp.sign: "sign", + sp.Heaviside: "step", + # Exp functions + sp.exp: "exp", + sp.log: "log", + # Trigonometric Functions + sp.sin: "sin", + sp.cos: "cos", + sp.tan: "tan", + # Trigonometric Inverses + sp.asin: "arcsin", + sp.acos: "arccos", + sp.atan: "arctan", + } \ No newline at end of file diff --git a/symbolicregression/envs/utils.py b/symbolicregression/envs/utils.py new file mode 100644 index 0000000..9bbd1a2 --- /dev/null +++ b/symbolicregression/envs/utils.py @@ -0,0 +1,59 @@ +def zip_dic(lst): + dico = {} + for d in lst: + for k in d: + if k not in dico: + dico[k] = [] + dico[k].append(d[k]) + for k in dico: + if isinstance(dico[k][0], dict): + dico[k] = zip_dic(dico[k]) + return dico + + +def unsqueeze_dic(dico): + dico_copy = {} + for d in dico: + if isinstance(dico[d], dict): + dico_copy[d] = unsqueeze_dic(dico[d]) + else: + dico_copy[d] = [dico[d]] + return dico_copy + + +def squeeze_dic(dico): + dico_copy = {} + for d in dico: + if isinstance(dico[d], dict): + dico_copy[d] = squeeze_dic(dico[d]) + else: + dico_copy[d] = dico[d][0] + return dico_copy + + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def getSizeOfNestedList(listOfElem): + """Get number of elements in a nested list""" + count = 0 + # Iterate over the list + for elem in listOfElem: + # Check if type of element is list + if type(elem) == list: + # Again call this function to get the size of this element + count += getSizeOfNestedList(elem) + else: + count += 1 + return count + + +class ZMQNotReady(Exception): + pass + + +class ZMQNotReadySample: + pass diff --git a/symbolicregression/metrics.py b/symbolicregression/metrics.py new file mode 100644 index 0000000..29edadd --- /dev/null +++ b/symbolicregression/metrics.py @@ 
-0,0 +1,163 @@ +from sklearn.metrics import r2_score, mean_squared_error +from collections import defaultdict +import numpy as np +import scipy + +def compute_metrics(infos, metrics="r2"): + results = defaultdict(list) + if metrics == "": + return {} + + if "true" in infos: + true, predicted = infos["true"], infos["predicted"] + assert len(true) == len(predicted), "issue with len, true: {}, predicted: {}".format(len(true), len(predicted)) + for i in range(len(true)): + if predicted[i] is None: continue + if len(true[i].shape)==2: + true[i]=true[i][:,0] + if len(predicted[i].shape)==2: + predicted[i]=predicted[i][:,0] + assert true[i].shape == predicted[i].shape, "Problem with shapes: {}, {}".format(true[i].shape, predicted[i].shape) + + for metric in metrics.split(","): + if metric == "r2": + true, predicted = infos["true"], infos["predicted"] + for i in range(len(true)): + if predicted[i] is None or np.isnan(np.min(predicted[i])): + #print(predicted[i]) + results[metric].append(np.nan) + else: + try: + results[metric].append(r2_score(true[i], predicted[i])) + except Exception as e: + #print(e, metric, true[i], predicted[i]) + results[metric].append(np.nan) + if metric == "r2_zero": + true, predicted = infos["true"], infos["predicted"] + for i in range(len(true)): + if predicted[i] is None or np.isnan(np.min(predicted[i])): + #print(predicted[i]) + results[metric].append(np.nan) + else: + try: + results[metric].append(max(0, r2_score(true[i], predicted[i]))) + except Exception as e: + #print(e, metric, true[i], predicted[i]) + results[metric].append(np.nan) + + elif metric.startswith("accuracy_l1"): + if metric == "accuracy_l1": + atol, rtol = 0.0, 0.1 + tolerance_point = 0.95 + elif metric == "accuracy_l1_biggio": + ## default is biggio et al. + atol, rtol = 1e-3, 0.05 + tolerance_point = 0.95 + else: + atol = 0 #float(metric.split("_")[-3]) + rtol = float(metric.split("_")[-1]) + tolerance_point = 0.95 #float(metric.split("_")[-1]) + + true, predicted = infos["true"], infos["predicted"] + for i in range(len(true)): + if predicted[i] is None or np.isnan(np.min(predicted[i])): + results[metric].append(np.nan) + else: + try: + is_close = np.isclose(predicted[i], true[i], atol=atol, rtol=rtol) + results[metric].append(float(is_close.mean()>=tolerance_point)) + except Exception as e: + print(e, metric, true[i], predicted[i]) + results[metric].append(np.nan) + + elif metric == "_mse": + true, predicted = infos["true"], infos["predicted"] + for i in range(len(true)): + if predicted[i] is None or np.isnan(np.min(predicted[i])): + results[metric].append(np.nan) + else: + try: + results[metric].append(mean_squared_error(true[i], predicted[i])) + except Exception as e: + results[metric].append(np.nan) + elif metric == "_nmse": + true, predicted = infos["true"], infos["predicted"] + for i in range(len(true)): + if predicted[i] is None or np.isnan(np.min(predicted[i])): + results[metric].append(np.nan) + else: + try: + mean_y = np.mean(true[i]) + NMSE = (np.mean(np.square(true[i]- predicted[i])))/mean_y + results[metric].append(NMSE) + except Exception as e: + results[metric].append(np.nan) + elif metric == "_rmse": + true, predicted = infos["true"], infos["predicted"] + for i in range(len(true)): + if predicted[i] is None or np.isnan(np.min(predicted[i])): + results[metric].append(np.nan) + else: + try: + results[metric].append(mean_squared_error(true[i], predicted[i], squared=False)) + except Exception as e: + results[metric].append(np.nan) + elif metric == "_complexity": + if "predicted_tree" not 
in infos: + results[metric].extend([np.nan for _ in range(len(infos["true"]))]) + continue + predicted_tree = infos["predicted_tree"] + for i in range(len(predicted_tree)): + if predicted_tree[i] is None: + results[metric].append(np.nan) + else: + results[metric].append(len(predicted_tree[i].prefix().split(","))) + + elif metric == "_relative_complexity": + if "tree" not in infos or "predicted_tree" not in infos: + results[metric].extend([np.nan for _ in range(len(infos["true"]))]) + continue + tree = infos["tree"] + predicted_tree = infos["predicted_tree"] + for i in range(len(predicted_tree)): + if predicted_tree[i] is None: + results[metric].append(np.nan) + else: + results[metric].append(len(predicted_tree[i].prefix().split(",")) - len(tree[i].prefix().split(","))) + + elif metric == "is_symbolic_solution": + + true, predicted = infos["true"], infos["predicted"] + for i in range(len(true)): + if predicted[i] is None or np.isnan(np.min(predicted[i])): + results[metric].append(np.nan) + else: + try: + diff = true[i] - predicted[i] + div = true[i] / (predicted[i] + 1e-100) + std_diff = scipy.linalg.norm( + np.abs(diff - diff.mean(0)) + ) + std_div = scipy.linalg.norm( + np.abs(div - div.mean(0)) + ) + if std_diff<1e-10 and std_div<1e-10: results[metric].append(1.0) + else: results[metric].append(0.0) + except Exception as e: + #print(e, metric, infos["predicted_tree"][i].infix()) + results[metric].append(np.nan) + + elif metric == "_l1_error": + true, predicted = infos["true"], infos["predicted"] + for i in range(len(true)): + if predicted[i] is None or np.isnan(np.min(predicted[i])): + results[metric].append(np.nan) + else: + try: + l1_error = np.mean(np.abs((true[i] - predicted[i]))) + if np.isnan(l1_error): results[metric].append(np.infty) + else: results[metric].append(l1_error) + except Exception as e: + #print(e, metric, true[i], predicted[i], infos["predicted_tree"][i].infix()) + results[metric].append(np.nan) + return results diff --git a/symbolicregression/model/__init__.py b/symbolicregression/model/__init__.py new file mode 100644 index 0000000..8a4bcc0 --- /dev/null +++ b/symbolicregression/model/__init__.py @@ -0,0 +1,65 @@ +# Copyright (c) 2020-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +from .embedders import LinearPointEmbedder +from .transformer import TransformerModel +from .sklearn_wrapper import SymbolicTransformerRegressor +from .model_wrapper import ModelWrapper + +def build_modules(env, params): + """ + Build modules. 
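+    Returns a dictionary with three modules: an "embedder" mapping (x, y)
+    point clouds to input embeddings, an "encoder" operating on those prior
+    embeddings, and an autoregressive "decoder" over equation tokens.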
+ """ + modules = {} + modules["embedder"] = LinearPointEmbedder(params, env) + env.get_length_after_batching = modules["embedder"].get_length_after_batching + + modules["encoder"] = TransformerModel( + params, + env.float_id2word, + is_encoder=True, + with_output=False, + use_prior_embeddings=True, + positional_embeddings=params.enc_positional_embeddings + + ) + modules["decoder"] = TransformerModel( + params, + env.equation_id2word, + is_encoder=False, + with_output=True, + use_prior_embeddings=False, + positional_embeddings=params.dec_positional_embeddings + ) + + # reload pretrained modules + if params.reload_model != "": + logger.info(f"Reloading modules from {params.reload_model} ...") + reloaded = torch.load(params.reload_model) + for k, v in modules.items(): + assert k in reloaded + if all([k2.startswith("module.") for k2 in reloaded[k].keys()]): + reloaded[k] = { + k2[len("module.") :]: v2 for k2, v2 in reloaded[k].items() + } + v.load_state_dict(reloaded[k]) + + # log + for k, v in modules.items(): + logger.debug(f"{v}: {v}") + for k, v in modules.items(): + logger.info( + f"Number of parameters ({k}): {sum([p.numel() for p in v.parameters() if p.requires_grad])}" + ) + + # cuda + if not params.cpu: + for v in modules.values(): + v.cuda() + + return modules diff --git a/symbolicregression/model/embedders.py b/symbolicregression/model/embedders.py new file mode 100644 index 0000000..f2678de --- /dev/null +++ b/symbolicregression/model/embedders.py @@ -0,0 +1,141 @@ +from typing import Tuple, List +from abc import ABC, abstractmethod +import torch +import torch.nn as nn +from symbolicregression.utils import to_cuda +import torch.nn.functional as F + +MultiDimensionalFloat = List[float] +XYPair = Tuple[MultiDimensionalFloat, MultiDimensionalFloat] +Sequence = List[XYPair] + + +class Embedder(ABC, nn.Module): + """ + Base class for embedders, transforms a sequence of pairs into a sequence of embeddings. 
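+
+    Subclasses implement `encode` (tokenize raw sequences) and `forward`,
+    which returns a (sequences_embeddings, sequences_len) pair ready to be
+    fed to the encoder.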
+ """ + + def __init__(self): + super().__init__() + pass + + @abstractmethod + def forward(self, sequences: List[Sequence]) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + @abstractmethod + def encode(self, sequences: List[Sequence]) -> List[torch.Tensor]: + pass + + def batch(self, seqs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def embed(self, batch: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def get_length_after_batching(self, sequences: List[Sequence]) -> List[int]: + pass + +class LinearPointEmbedder(Embedder): + def __init__(self, params, env): + from .transformer import Embedding + + super().__init__() + self.env = env + self.params = params + self.input_dim = params.emb_emb_dim + self.output_dim = params.enc_emb_dim + self.embeddings = Embedding( + len(self.env.float_id2word), + self.input_dim, + padding_idx=self.env.float_word2id[""], + ) + self.float_scalar_descriptor_len = (2 + self.params.mantissa_len) + self.total_dimension = self.params.max_input_dimension + self.params.max_output_dimension + self.float_vector_descriptor_len = self.float_scalar_descriptor_len * self.total_dimension + + self.activation_fn = F.relu + size = self.float_vector_descriptor_len*self.input_dim + hidden_size = size * self.params.emb_expansion_factor + self.hidden_layers = nn.ModuleList() + self.hidden_layers.append(nn.Linear(size, hidden_size)) + for i in range(self.params.n_emb_layers-1): + self.hidden_layers.append(nn.Linear(hidden_size, hidden_size)) + self.fc = nn.Linear(hidden_size, self.output_dim) + self.max_seq_len = self.params.max_len + + def compress( + self, sequences_embeddings: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Takes: (N_max * (d_in+d_out)*(2+mantissa_len), B, d) tensors + Returns: (N_max, B, d) + + """ + max_len, bs, float_descriptor_length, dim = sequences_embeddings.size() + sequences_embeddings = sequences_embeddings.view(max_len, bs, -1) + for layer in self.hidden_layers: sequences_embeddings = self.activation_fn(layer(sequences_embeddings)) + sequences_embeddings = self.fc(sequences_embeddings) + return sequences_embeddings + + def forward(self, sequences: List[Sequence]) -> Tuple[torch.Tensor, torch.Tensor]: + sequences = self.encode(sequences) + sequences, sequences_len = self.batch(sequences) + sequences, sequences_len = to_cuda(sequences, sequences_len, use_cpu=self.fc.weight.device.type=="cpu") + sequences_embeddings = self.embed(sequences) + sequences_embeddings = self.compress(sequences_embeddings) + return sequences_embeddings, sequences_len + + def encode(self, sequences: List[Sequence]) -> List[torch.Tensor]: + res = [] + for seq in sequences: + seq_toks = [] + for x, y in seq: + x_toks = self.env.float_encoder.encode(x) + y_toks = self.env.float_encoder.encode(y) + input_dim = int(len(x_toks) / (2 + self.params.mantissa_len)) + output_dim = int(len(y_toks) / (2 + self.params.mantissa_len)) + x_toks = [ + *x_toks, + *[ + "" + for _ in range( + (self.params.max_input_dimension - input_dim) + * self.float_scalar_descriptor_len + ) + ], + ] + y_toks = [ + *y_toks, + *[ + "" + for _ in range( + (self.params.max_output_dimension - output_dim) + * self.float_scalar_descriptor_len + ) + ], + ] + toks = [*x_toks, *y_toks] + seq_toks.append([self.env.float_word2id[tok] for tok in toks]) + res.append(torch.LongTensor(seq_toks)) + return res + + def batch(self, seqs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + pad_id = self.env.float_word2id[""] + 
lengths = [len(x) for x in seqs] + bs, slen = len(lengths), max(lengths) + sent = torch.LongTensor(slen, bs, self.float_vector_descriptor_len).fill_(pad_id) + for i, seq in enumerate(seqs): + sent[0 : len(seq), i, :] = seq + return sent, torch.LongTensor(lengths) + + def embed(self, batch: torch.Tensor) -> torch.Tensor: + return self.embeddings(batch) + + def get_length_after_batching(self, seqs: List[Sequence]) -> torch.Tensor: + lengths = torch.zeros(len(seqs), dtype=torch.long) + for i, seq in enumerate(seqs): + lengths[i] = len(seq) + assert lengths.max() <= self.max_seq_len, "issue with lengths after batching" + return lengths diff --git a/symbolicregression/model/model_wrapper.py b/symbolicregression/model/model_wrapper.py new file mode 100644 index 0000000..cf2d2ca --- /dev/null +++ b/symbolicregression/model/model_wrapper.py @@ -0,0 +1,106 @@ +import numpy as np +import torch +import torch.nn as nn + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + +class ModelWrapper(nn.Module): + """""" + def __init__(self, + env=None, + embedder=None, + encoder=None, + decoder=None, + beam_type="search", + beam_length_penalty=1, + beam_size=1, + beam_early_stopping=True, + max_generated_output_len=200, + beam_temperature=1., + ): + super().__init__() + + self.env = env + self.embedder = embedder + self.encoder = encoder + self.decoder = decoder + self.beam_type = beam_type + self.beam_early_stopping = beam_early_stopping + self.max_generated_output_len = max_generated_output_len + self.beam_size = beam_size + self.beam_length_penalty = beam_length_penalty + self.beam_temperature = beam_temperature + self.device=next(self.embedder.parameters()).device + + @torch.no_grad() + def forward( + self, + input, + ): + + """ + x: bags of sequences (B, T) + """ + + env = self.env + embedder, encoder, decoder = self.embedder, self.encoder, self.decoder + + B, T = len(input), max([len(xi) for xi in input]) + outputs = [] + for chunk in chunks(np.arange(B), min(int(10000/T), int(100000/self.beam_size/self.max_generated_output_len))): + x, x_len = embedder([input[idx] for idx in chunk]) + encoded = encoder("fwd", x=x, lengths=x_len, causal=False).transpose(0,1) + bs = encoded.shape[0] + + ### Greedy solution. 
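+            ### One greedy hypothesis per input is always decoded first;
+            ### beam-search or sampling hypotheses below are appended to it.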
+ generations, _ = decoder.generate( + encoded, + x_len, + sample_temperature=None, + max_len=self.max_generated_output_len, + ) + + generations = generations.unsqueeze(-1).view(generations.shape[0], bs, 1) + generations = generations.transpose(0,1).transpose(1,2).cpu().tolist() + generations = [list(filter(lambda x: x is not None, [env.idx_to_infix(hyp[1:-1], is_float=False, str_array=False) for hyp in generations[i]])) for i in range(bs)] + + if self.beam_type == "search": + _, _, search_generations = decoder.generate_beam( + encoded, + x_len, + beam_size=self.beam_size, + length_penalty=self.beam_length_penalty, + max_len=self.max_generated_output_len, + early_stopping=self.beam_early_stopping, + ) + search_generations = [sorted([hyp for hyp in search_generations[i].hyp], key=lambda s: s[0], reverse=True) for i in range(bs)] + search_generations = [list(filter(lambda x: x is not None, [env.idx_to_infix(hyp.cpu().tolist()[1:], is_float=False, str_array=False) for (_, hyp) in search_generations[i]])) for i in range(bs)] + for i in range(bs): + generations[i].extend(search_generations[i]) + + elif self.beam_type == "sampling": + num_samples = self.beam_size + encoded = (encoded.unsqueeze(1) + .expand((bs, num_samples) + encoded.shape[1:]) + .contiguous() + .view((bs * num_samples,) + encoded.shape[1:]) + ) + x_len = x_len.unsqueeze(1).expand(bs, num_samples).contiguous().view(-1) + sampling_generations, _ = decoder.generate( + encoded, + x_len, + sample_temperature = self.beam_temperature, + max_len=self.max_generated_output_len + ) + sampling_generations = sampling_generations.unsqueeze(-1).view(sampling_generations.shape[0], bs, num_samples) + sampling_generations = sampling_generations.transpose(0, 1).transpose(1, 2).cpu().tolist() + sampling_generations = [list(filter(lambda x: x is not None, [env.idx_to_infix(hyp[1:-1], is_float=False, str_array=False) for hyp in sampling_generations[i]])) for i in range(bs)] + for i in range(bs): + generations[i].extend(sampling_generations[i]) + else: + raise NotImplementedError + outputs.extend(generations) + return outputs \ No newline at end of file diff --git a/symbolicregression/model/sklearn_wrapper.py b/symbolicregression/model/sklearn_wrapper.py new file mode 100644 index 0000000..f1b837e --- /dev/null +++ b/symbolicregression/model/sklearn_wrapper.py @@ -0,0 +1,233 @@ +import math, time, copy +import numpy as np +import torch +from collections import defaultdict +from symbolicregression.metrics import compute_metrics +from sklearn.base import BaseEstimator +import symbolicregression.model.utils_wrapper as utils_wrapper +import traceback + +class SymbolicTransformerRegressor(BaseEstimator): + + def __init__(self, + model=None, + max_input_points=10000, + max_number_bags=-1, + stop_refinement_after=1, + n_trees_to_refine=1, + rescale=True + ): + + self.max_input_points = max_input_points + self.max_number_bags = max_number_bags + self.model = model + self.stop_refinement_after = stop_refinement_after + self.n_trees_to_refine = n_trees_to_refine + self.rescale = rescale + + def set_args(self, args={}): + for arg, val in args.items(): + assert hasattr(self, arg), "{} arg does not exist".format(arg) + setattr(self, arg, val) + + def fit( + self, + X, + Y, + verbose=False + ): + self.start_fit = time.time() + + if not isinstance(X, list): + X = [X] + Y = [Y] + n_datasets = len(X) + + scaler = utils_wrapper.StandardScaler() if self.rescale else None + scale_params = {} + if scaler is not None: + scaled_X = [] + for i, x in enumerate(X): + 
scaled_X.append(scaler.fit_transform(x)) + scale_params[i]=scaler.get_params() + else: + scaled_X = X + + inputs, inputs_ids = [], [] + for seq_id in range(len(scaled_X)): + for seq_l in range(len(scaled_X[seq_id])): + y_seq = Y[seq_id] + if len(y_seq.shape)==1: + y_seq = np.expand_dims(y_seq,-1) + if seq_l%self.max_input_points == 0: + inputs.append([]) + inputs_ids.append(seq_id) + inputs[-1].append([scaled_X[seq_id][seq_l], y_seq[seq_l]]) + + if self.max_number_bags>0: + inputs = inputs[:self.max_number_bags] + inputs_ids = inputs_ids[:self.max_number_bags] + + forward_time=time.time() + outputs = self.model(inputs) ##Forward transformer: returns predicted functions + if verbose: print("Finished forward in {} secs".format(time.time()-forward_time)) + + candidates = defaultdict(list) + assert len(inputs) == len(outputs), "Problem with inputs and outputs" + for i in range(len(inputs)): + input_id = inputs_ids[i] + candidate = outputs[i] + candidates[input_id].extend(candidate) + assert len(candidates.keys())==n_datasets + + self.tree = {} + for input_id, candidates_id in candidates.items(): + if len(candidates_id)==0: + self.tree[input_id]=None + continue + + refined_candidates = self.refine(scaled_X[input_id], Y[input_id], candidates_id, verbose=verbose) + for i,candidate in enumerate(refined_candidates): + if scaler is not None: + refined_candidates[i]["predicted_tree"]=scaler.rescale_function(self.model.env, candidate["predicted_tree"], *scale_params[input_id]) + else: + refined_candidates[i]["predicted_tree"]=candidate["predicted_tree"] + self.tree[input_id] = refined_candidates + + @torch.no_grad() + def evaluate_tree(self, tree, X, y, metric): + numexpr_fn = self.model.env.simplifier.tree_to_numexpr_fn(tree) + y_tilde = numexpr_fn(X)[:,0] + metrics = compute_metrics({"true": [y], "predicted": [y_tilde], "predicted_tree": [tree]}, metrics=metric) + return metrics[metric][0] + + def order_candidates(self, X, y, candidates, metric="_mse", verbose=False): + scores = [] + for candidate in candidates: + if metric not in candidate: + score = self.evaluate_tree(candidate["predicted_tree"], X, y, metric) + if math.isnan(score): + score = np.infty if metric.startswith("_") else -np.infty + else: + score = candidates[metric] + scores.append(score) + ordered_idx = np.argsort(scores) + if not metric.startswith("_"): ordered_idx=list(reversed(ordered_idx)) + candidates = [candidates[i] for i in ordered_idx] + return candidates + + def refine(self, X, y, candidates, verbose): + refined_candidates = [] + + ## For skeleton model + for i, candidate in enumerate(candidates): + candidate_skeleton, candidate_constants = self.model.env.generator.function_to_skeleton(candidate, constants_with_idx=True) + if "CONSTANT" in candidate_constants: + candidates[i] = self.model.env.wrap_equation_floats(candidate_skeleton, np.random.randn(len(candidate_constants))) + + candidates = [{"refinement_type": "NoRef", "predicted_tree": candidate, "time": time.time()-self.start_fit} for candidate in candidates] + candidates = self.order_candidates(X, y, candidates, metric="_mse", verbose=verbose) + + ## REMOVE SKELETON DUPLICATAS + skeleton_candidates, candidates_to_remove = {}, [] + for i, candidate in enumerate(candidates): + skeleton_candidate, _ = self.model.env.generator.function_to_skeleton(candidate["predicted_tree"], constants_with_idx=False) + if skeleton_candidate.infix() in skeleton_candidates: + candidates_to_remove.append(i) + else: + skeleton_candidates[skeleton_candidate.infix()]=1 + if verbose: 
print("Removed {}/{} skeleton duplicata".format(len(candidates_to_remove), len(candidates))) + + candidates = [candidates[i] for i in range(len(candidates)) if i not in candidates_to_remove] + if self.n_trees_to_refine>0: + candidates_to_refine = candidates[:self.n_trees_to_refine] + else: + candidates_to_refine = copy.deepcopy(candidates) + + for candidate in candidates_to_refine: + refinement_strategy = utils_wrapper.BFGSRefinement() + candidate_skeleton, candidate_constants = self.model.env.generator.function_to_skeleton(candidate["predicted_tree"], constants_with_idx=True) + try: + refined_candidate = refinement_strategy.go(env=self.model.env, + tree=candidate_skeleton, + coeffs0=candidate_constants, + X=X, + y=y, + downsample=1024, + stop_after=self.stop_refinement_after) + + except Exception as e: + if verbose: + print(e) + #traceback.format_exc() + continue + + if refined_candidate is not None: + refined_candidates.append({ + "refinement_type": "BFGS", + "predicted_tree": refined_candidate, + }) + candidates.extend(refined_candidates) + candidates = self.order_candidates(X, y, candidates, metric="r2") + + for candidate in candidates: + if "time" not in candidate: + candidate["time"]=time.time()-self.start_fit + return candidates + + def __str__(self): + if hasattr(self, "tree"): + for tree_idx in range(len(self.tree)): + for gen in self.tree[tree_idx]: + print(gen) + return "Transformer" + + def retrieve_refinements_types(self): + return ["BFGS", "NoRef"] + + def retrieve_tree(self, refinement_type=None, tree_idx=0, with_infos=False): + if tree_idx == -1: idxs = [_ for _ in range(len(self.tree))] + else: idxs = [tree_idx] + best_trees = [] + for idx in idxs: + best_tree = copy.deepcopy(self.tree[idx]) + if best_tree and refinement_type is not None: + best_tree = list(filter(lambda gen: gen["refinement_type"]==refinement_type, best_tree)) + if not best_tree: + if with_infos: + best_trees.append({"predicted_tree": None, "refinement_type": None, "time": None}) + else: + best_trees.append(None) + else: + if with_infos: + best_trees.append(best_tree[0]) + else: + best_trees.append(best_tree[0]["predicted_tree"]) + if tree_idx != -1: return best_trees[0] + else: return best_trees + + + def predict(self, X, refinement_type=None, tree_idx=0, batch=False): + if not isinstance(X, list): + X = [X] + res = [] + if batch: + tree = self.retrieve_tree(refinement_type=refinement_type, tree_idx = -1) + for tree_idx in range(len(tree)): + X_idx = X[tree_idx] + if tree[tree_idx] is None: + res.append(None) + else: + numexpr_fn = self.model.env.simplifier.tree_to_numexpr_fn(tree[tree_idx]) + y = numexpr_fn(X_idx)[:,0] + res.append(y) + return res + else: + X_idx = X[tree_idx] + tree = self.retrieve_tree(refinement_type=refinement_type, tree_idx = tree_idx) + if tree is not None: + numexpr_fn = self.model.env.simplifier.tree_to_numexpr_fn(tree) + y = numexpr_fn(X_idx)[:,0] + return y + else: + return None \ No newline at end of file diff --git a/symbolicregression/model/transformer.py b/symbolicregression/model/transformer.py new file mode 100644 index 0000000..f501399 --- /dev/null +++ b/symbolicregression/model/transformer.py @@ -0,0 +1,906 @@ +# Copyright (c) 2020-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +from logging import getLogger +import math +import itertools +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +N_MAX_POSITIONS = 4096 # maximum input sequence length + + +logger = getLogger() + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def create_sinusoidal_embeddings(n_pos, dim, out): + position_enc = np.array( + [ + [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] + for pos in range(n_pos) + ] + ) + out.detach_() + out.requires_grad = False + out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + + +def get_masks(slen, lengths, causal): + """ + Generate hidden states mask, and optionally an attention mask. + """ + assert lengths.max().item() <= slen + bs = lengths.size(0) + alen = torch.arange(slen, dtype=torch.long, device=lengths.device) + mask = alen < lengths[:, None] + + # attention mask is the same as mask, or triangular inferior attention (causal) + if causal: + attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None] + else: + attn_mask = mask + + # sanity check + assert mask.size() == (bs, slen) + assert causal is False or attn_mask.size() == (bs, slen, slen) + + return mask, attn_mask + + +class MultiHeadAttention(nn.Module): + + NEW_ID = itertools.count() + + def __init__(self, n_heads, dim, src_dim, dropout, normalized_attention): + super().__init__() + self.layer_id = next(MultiHeadAttention.NEW_ID) + self.dim = dim + self.src_dim = src_dim + self.n_heads = n_heads + self.dropout = dropout + self.normalized_attention = normalized_attention + assert self.dim % self.n_heads == 0 + + self.q_lin = nn.Linear(dim, dim) + self.k_lin = nn.Linear(src_dim, dim) + self.v_lin = nn.Linear(src_dim, dim) + self.out_lin = nn.Linear(dim, dim) + if self.normalized_attention: + self.attention_scale = nn.Parameter( + torch.tensor(1.0 / math.sqrt(dim // n_heads)) + ) + + def forward(self, input, mask=None, kv=None, use_cache=False): + """ + Self-attention (if kv is None) + or attention over source sentence (provided by kv). 
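+        When `use_cache` is set, keys/values already computed for this layer
+        are read from `self.cache` and extended, enabling incremental decoding.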
+ Input is (bs, qlen, dim) + Mask is (bs, klen) (non-causal) or (bs, klen, klen) + """ + assert not (use_cache and self.cache is None) + bs, qlen, dim = input.size() + if kv is None: + klen = qlen if not use_cache else self.cache["slen"] + qlen + else: + klen = kv.size(1) + assert dim == self.dim, "Dimensions do not match: %s input vs %s configured" % ( + dim, + self.dim, + ) + n_heads = self.n_heads + dim_per_head = dim // n_heads + + def shape(x): + """projection""" + return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) + + def unshape(x): + """compute context""" + return ( + x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) + ) + + q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) + if kv is None: + k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) + v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) + elif not use_cache or self.layer_id not in self.cache: + k = v = kv + k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) + v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) + + if use_cache: + if self.layer_id in self.cache: + if kv is None: + k_, v_ = self.cache[self.layer_id] + k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) + v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) + else: + k, v = self.cache[self.layer_id] + self.cache[self.layer_id] = (k, v) + if self.normalized_attention: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + q = q * self.attention_scale + else: + q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) + + scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) + + if mask is not None: + mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen) + mask = ( + (mask == 0).view(mask_reshape).expand_as(scores) + ) # (bs, n_heads, qlen, klen) + scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, qlen, klen) + + weights = F.softmax(scores.float(), dim=-1).type_as( + scores + ) # (bs, n_heads, qlen, klen) + weights = F.dropout( + weights, p=self.dropout, training=self.training + ) # (bs, n_heads, qlen, klen) + context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) + context = unshape(context) # (bs, qlen, dim) + + if TransformerModel.STORE_OUTPUTS and not self.training: + self.outputs = weights.detach().cpu() + + return self.out_lin(context) + + +class TransformerFFN(nn.Module): + def __init__(self, in_dim, dim_hidden, out_dim, hidden_layers, dropout): + super().__init__() + self.dropout = dropout + self.hidden_layers = hidden_layers + self.midlin = nn.ModuleList() + self.lin1 = nn.Linear(in_dim, dim_hidden) + for i in range(1, self.hidden_layers): + self.midlin.append(nn.Linear(dim_hidden, dim_hidden)) + self.lin2 = nn.Linear(dim_hidden, out_dim) + + def forward(self, input): + x = self.lin1(input) + x = F.relu(x) + for mlin in self.midlin: + x = mlin(x) + x = F.relu(x) + x = self.lin2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + return x + + +class TransformerModel(nn.Module): + + STORE_OUTPUTS = True + + def __init__(self, params, id2word, is_encoder, with_output, use_prior_embeddings, positional_embeddings): + """ + Transformer model (encoder or decoder). 
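+        The same class serves as encoder (bidirectional self-attention) or
+        decoder (causal self-attention plus cross-attention over `src_enc`),
+        depending on `is_encoder`.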
+ """ + super().__init__() + + # encoder / decoder, output layer + self.dtype = torch.half if params.fp16 else torch.float + self.is_encoder = is_encoder + self.is_decoder = not is_encoder + self.with_output = with_output + + self.apex = params.nvidia_apex + + # dictionary + + self.id2word = id2word + self.word2id = {s: i for i, s in self.id2word.items()} + self.eos_index = self.word2id[""] + self.pad_index = self.word2id[""] + + self.n_words = len(self.id2word) + assert len(self.id2word) == self.n_words + + # model parameters + self.dim = ( + params.enc_emb_dim if is_encoder else params.dec_emb_dim + ) # 512 by default + self.src_dim = params.enc_emb_dim + self.hidden_dim = self.dim * 4 # 2048 by default + self.n_hidden_layers = ( + params.n_enc_hidden_layers if is_encoder else params.n_dec_hidden_layers + ) + self.n_heads = ( + params.n_enc_heads if is_encoder else params.n_dec_heads + ) # 8 by default + self.n_layers = params.n_enc_layers if is_encoder else params.n_dec_layers + self.dropout = params.dropout + self.attention_dropout = params.attention_dropout + self.norm_attention = params.norm_attention + assert ( + self.dim % self.n_heads == 0 + ), "transformer dim must be a multiple of n_heads" + + # embeddings + + if positional_embeddings is None or positional_embeddings == "alibi": + self.position_embeddings = None + elif positional_embeddings == "sinusoidal": + self.position_embeddings = Embedding(N_MAX_POSITIONS, self.dim) + create_sinusoidal_embeddings( + N_MAX_POSITIONS, self.dim, out=self.position_embeddings.weight + ) + elif positional_embeddings == "learnable": + self.position_embeddings = Embedding(N_MAX_POSITIONS, self.dim) + else: + raise NotImplementedError + + self.use_prior_embeddings = use_prior_embeddings + if not use_prior_embeddings: + self.embeddings = Embedding( + self.n_words, self.dim, padding_idx=self.pad_index + ) + else: self.embeddings = None + self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12) + + # transformer layers + self.attentions = nn.ModuleList() + self.layer_norm1 = nn.ModuleList() + self.ffns = nn.ModuleList() + self.layer_norm2 = nn.ModuleList() + if self.is_decoder: + self.layer_norm15 = nn.ModuleList() + self.encoder_attn = nn.ModuleList() + + for layer_id in range(self.n_layers): + self.attentions.append( + MultiHeadAttention( + self.n_heads, + self.dim, + self.dim, + dropout=self.attention_dropout, + normalized_attention=self.norm_attention, + ) + ) + self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12)) + if self.is_decoder: + self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12)) + self.encoder_attn.append( + MultiHeadAttention( + self.n_heads, + self.dim, + self.src_dim, + dropout=self.attention_dropout, + normalized_attention=self.norm_attention, + ) + ) + self.ffns.append( + TransformerFFN( + self.dim, + self.hidden_dim, + self.dim, + self.n_hidden_layers, + dropout=self.dropout, + ) + ) + self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12)) + + self.cache = None + + # output layer + if self.with_output: + assert not self.use_prior_embeddings + self.proj = nn.Linear( + self.dim, self.n_words, bias=True + ) ##added index for eos and tab + if params.share_inout_emb: + self.proj.weight = self.embeddings.weight + + def forward(self, mode, **kwargs): + """ + Forward function with different forward modes. + ### Small hack to handle PyTorch distributed. 
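+        mode == "fwd"     dispatches to `fwd` (encode or decode a batch),
+        mode == "predict" dispatches to `predict` (score masked positions).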
+ """ + if mode == "fwd": + return self.fwd(**kwargs) + elif mode == "predict": + return self.predict(**kwargs) + else: + raise Exception("Unknown mode: %s" % mode) + + def fwd( + self, + x, + lengths, + causal, + src_enc=None, + src_len=None, + positions=None, + use_cache=False, + + ): + """ + Inputs: + `x` LongTensor(slen, bs), containing word indices + `lengths` LongTensor(bs), containing the length of each sentence + `causal` Boolean, if True, the attention is only done over previous hidden states + `positions` LongTensor(slen, bs), containing word positions + """ + # lengths = (x != self.pad_index).float().sum(dim=1) + # mask = x != self.pad_index + + + + # check inputs + slen, bs = x.size()[:2] + assert lengths.size(0) == bs + assert lengths.max().item() <= slen + x = x.transpose(0, 1) # batch size as dimension 0 + assert (src_enc is None) == (src_len is None) + if src_enc is not None: + assert self.is_decoder + assert src_enc.size(0) == bs + assert not (use_cache and self.cache is None) + + # generate masks + mask, attn_mask = get_masks(slen, lengths, causal) + if self.is_decoder and src_enc is not None: + src_mask = ( + torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) + < src_len[:, None] + ) + + # positions + if positions is None: + positions = x.new(slen).long() + positions = torch.arange(slen, out=positions).unsqueeze(0) + else: + assert positions.size() == (slen, bs) + positions = positions.transpose(0, 1) + + # do not recompute cached elements + if use_cache: + _slen = slen - self.cache["slen"] + x = x[:, -_slen:] + positions = positions[:, -_slen:] + mask = mask[:, -_slen:] + attn_mask = attn_mask[:, -_slen:] + + # all layer outputs + if TransformerModel.STORE_OUTPUTS and not self.training: + self.outputs = [] + + # embeddings + if not self.use_prior_embeddings: + tensor = self.embeddings(x) + else: + tensor = x + + if self.position_embeddings is not None: + tensor = tensor + self.position_embeddings(positions).expand_as(tensor) + tensor = self.layer_norm_emb(tensor) + tensor = F.dropout(tensor, p=self.dropout, training=self.training) + tensor *= mask.unsqueeze(-1).to(tensor.dtype) + if TransformerModel.STORE_OUTPUTS and not self.training: + self.outputs.append(tensor.detach().cpu()) + + # transformer layers + for i in range(self.n_layers): + + # self attention + self.attentions[i].cache = self.cache + attn = self.attentions[i](tensor, attn_mask, use_cache=use_cache) + attn = F.dropout(attn, p=self.dropout, training=self.training) + tensor = tensor + attn + tensor = self.layer_norm1[i](tensor) + + # encoder attention (for decoder only) + if self.is_decoder and src_enc is not None: + self.encoder_attn[i].cache = self.cache + attn = self.encoder_attn[i]( + tensor, src_mask, kv=src_enc, use_cache=use_cache + ) + attn = F.dropout(attn, p=self.dropout, training=self.training) + tensor = tensor + attn + tensor = self.layer_norm15[i](tensor) + + # FFN + tensor = tensor + self.ffns[i](tensor) + tensor = self.layer_norm2[i](tensor) + + tensor *= mask.unsqueeze(-1).to(tensor.dtype) + if TransformerModel.STORE_OUTPUTS and not self.training: + self.outputs.append(tensor.detach().cpu()) + + # update cache length + if use_cache: + self.cache["slen"] += tensor.size(1) + + # move back sequence length to dimension 0 + tensor = tensor.transpose(0, 1) + + return tensor + + def predict(self, tensor, pred_mask, y, get_scores): + """ + Given the last hidden state, compute word scores and/or the loss. 
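+        The cross-entropy loss is computed only over the positions selected
+        by the mask.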
+ `pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when + we need to predict a word + `y` is a LongTensor of shape (pred_mask.sum(),) + `get_scores` is a boolean specifying whether we need to return scores + """ + x = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim) + assert (y == self.pad_index).sum().item() == 0 + scores = self.proj(x).view(-1, self.n_words) + loss = F.cross_entropy(scores.float(), y, reduction="mean") + return scores, loss + + def generate(self, src_enc, src_len, max_len=200, top_p=1.0, sample_temperature=None): + """ + Decode a sentence given initial start. + `x`: + - LongTensor(bs, slen) + W1 W2 W3 + W1 W2 W3 W4 + `lengths`: + - LongTensor(bs) [5, 6] + `positions`: + - False, for regular "arange" positions (LM) + - True, to reset positions from the new generation (MT) + """ + + # input batch + bs = len(src_len) + assert src_enc.size(0) == bs + + # generated sentences + generated = src_len.new(max_len, bs) # upcoming output + generated.fill_(self.pad_index) # fill upcoming ouput with + generated[0].fill_(self.eos_index) # we use for everywhere + + # positions + positions = src_len.new(max_len).long() + positions = ( + torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs) + ) + + # current position / max lengths / length of generated sentences / unfinished sentences + cur_len = 1 + gen_len = src_len.clone().fill_(1) + unfinished_sents = src_len.clone().fill_(1) + + # cache compute states + self.cache = {"slen": 0} + while cur_len < max_len: + + # compute word scores + tensor = self.forward( + "fwd", + x=generated[:cur_len], + lengths=gen_len, + positions=positions[:cur_len], + causal=True, + src_enc=src_enc, + src_len=src_len, + use_cache=True, + ) + assert tensor.size() == (1, bs, self.dim) + tensor = tensor.data[-1, :, :].to(self.dtype) # (bs, dim) ##BE CAREFUL + scores = self.proj(tensor) # (bs, n_words) + + # select next words: sample or greedy + if sample_temperature is None: + next_words = torch.topk(scores, 1)[1].squeeze(1) + else: + next_words = torch.multinomial( + F.softmax(scores.float() / sample_temperature, dim=1), num_samples=1 + ).squeeze(1) + assert next_words.size() == (bs,) + + # update generations / lengths / finished sentences / current length + generated[cur_len] = next_words * unfinished_sents + self.pad_index * ( + 1 - unfinished_sents + ) + gen_len.add_(unfinished_sents) + unfinished_sents.mul_(next_words.ne(self.eos_index).long()) + cur_len = cur_len + 1 + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sents.max() == 0: + break + + # add to unfinished sentences + if cur_len == max_len: + generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index) + # sanity check + assert (generated == self.eos_index).sum() == 2 * bs + generated = generated.unsqueeze(-1).view(generated.shape[0], bs) + return generated[:cur_len], gen_len + + def generate_beam( + self, src_enc, src_len, beam_size, length_penalty, early_stopping, max_len=200, + ): + """ + Decode a sentence given initial start. 
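+        Maintains `beam_size` hypotheses per input sentence and returns a
+        (decoded, tgt_len, generated_hyps) triple.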
+ `x`: + - LongTensor(bs, slen) + W1 W2 W3 + W1 W2 W3 W4 + `lengths`: + - LongTensor(bs) [5, 6] + `positions`: + - False, for regular "arange" positions (LM) + - True, to reset positions from the new generation (MT) + """ + + # check inputs + assert src_enc.size(0) == src_len.size(0) + assert beam_size >= 1 + # batch size / number of words + bs = len(src_len) + n_words = self.n_words + + + + # expand to beam size the source latent representations / source lengths + src_enc = ( + src_enc.unsqueeze(1) + .expand((bs, beam_size) + src_enc.shape[1:]) + .contiguous() + .view((bs * beam_size,) + src_enc.shape[1:]) + ) + src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1) + + # generated sentences (batch with beam current hypotheses) + generated = src_len.new(max_len, bs * beam_size) # upcoming output + generated.fill_(self.pad_index) # fill upcoming ouput with + generated[0].fill_(self.eos_index) # we use for everywhere + + # generated hypotheses + generated_hyps = [ + BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) + for _ in range(bs) + ] + + # positions + positions = src_len.new(max_len).long() + positions = ( + torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated) + ) + + # scores for each sentence in the beam + beam_scores = src_enc.new(bs, beam_size).float().fill_(0) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view(-1) + + # current position + cur_len = 1 + + # cache compute states + self.cache = {"slen": 0} + + # done sentences + done = [False for _ in range(bs)] + + while cur_len < max_len: + + # compute word scores + tensor = self.forward( + "fwd", + x=generated[:cur_len], + lengths=src_len.new(bs * beam_size).fill_(cur_len), + positions=positions[:cur_len], + causal=True, + src_enc=src_enc, + src_len=src_len, + use_cache=True, + ) + + + assert tensor.size() == (1, bs * beam_size, self.dim) + if self.apex: + tensor = tensor.data[-1, :, :].to(self.dtype) # (bs * beam_size, dim) + else: + tensor = tensor.data[ + -1, :, : + ] # .to(soui elf.dtype) # (bs * beam_size, dim) + scores = self.proj(tensor) # (bs * beam_size, n_words) + scores = F.log_softmax(scores.float(), dim=-1) # (bs * beam_size, n_words) + assert scores.size() == (bs * beam_size, n_words) + + # select next words with scores + _scores = scores + beam_scores[:, None].expand_as( + scores + ) # (bs * beam_size, n_words) + _scores = _scores.view(bs, beam_size * n_words) # (bs, beam_size * n_words) + + next_scores, next_words = torch.topk( + _scores, 2 * beam_size, dim=1, largest=True, sorted=True + ) + assert next_scores.size() == next_words.size() == (bs, 2 * beam_size) + + # next batch beam content + # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch) + next_batch_beam = [] + + # for each sentence + for sent_id in range(bs): + + # if we are done with this sentence + done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done( + next_scores[sent_id].max().item() + ) + if done[sent_id]: + next_batch_beam.extend( + [(0, self.pad_index, 0)] * beam_size + ) # pad the batch + continue + + # next sentence beam content + next_sent_beam = [] + + # next words for this sentence + for idx, value in zip(next_words[sent_id], next_scores[sent_id]): + + # get beam and word IDs + beam_id = torch.div(idx, n_words, rounding_mode='trunc') + word_id = idx % n_words + + # end of sentence, or next word + if word_id == self.eos_index or cur_len + 1 == max_len: + generated_hyps[sent_id].add( + generated[:cur_len, sent_id * beam_size 
+ beam_id] + .clone() + .cpu(), + value.item(), + ) + else: + next_sent_beam.append( + (value, word_id, sent_id * beam_size + beam_id) + ) + + # the beam for next step is full + if len(next_sent_beam) == beam_size: + break + + # update next beam content + assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size + if len(next_sent_beam) == 0: + next_sent_beam = [ + (0, self.pad_index, 0) + ] * beam_size # pad the batch + next_batch_beam.extend(next_sent_beam) + assert len(next_batch_beam) == beam_size * (sent_id + 1) + + # sanity check / prepare next batch + assert len(next_batch_beam) == bs * beam_size + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_words = generated.new([x[1] for x in next_batch_beam]) + beam_idx = src_len.new([x[2] for x in next_batch_beam]) + + # re-order batch and internal states + generated = generated[:, beam_idx] + generated[cur_len] = beam_words + for k in self.cache.keys(): + if k != "slen": + self.cache[k] = ( + self.cache[k][0][beam_idx], + self.cache[k][1][beam_idx], + ) + # update current length + cur_len = cur_len + 1 + + # stop when we are done with each sentence + if all(done): + break + + # def get_coeffs(s): + # roots = [int(s[i + 2]) for i, c in enumerate(s) if c == 'x'] + # poly = np.poly1d(roots, r=True) + # coeffs = list(poly.coefficients.astype(np.int64)) + # return [c % 10 for c in coeffs], coeffs + + # visualize hypotheses + # print([len(x) for x in generated_hyps], cur_len) + # globals().update( locals() ); + # !import code; code.interact(local=vars()) + # for ii in range(bs): + # for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True): + # hh = " ".join(self.id2word[x] for x in ww.tolist()) + # print(f"{ss:+.4f} {hh}") + # # cc = get_coeffs(hh[4:]) + # # print(f"{ss:+.4f} {hh} || {cc[0]} || {cc[1]}") + # print("") + + # select the best hypotheses + tgt_len = src_len.new(bs) + best = [] + + for i, hypotheses in enumerate(generated_hyps): + best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] + tgt_len[i] = len(best_hyp) + 1 # +1 for the symbol + best.append(best_hyp) + + # generate target batch + decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index) + for i, hypo in enumerate(best): + decoded[: tgt_len[i] - 1, i] = hypo + decoded[tgt_len[i] - 1, i] = self.eos_index + + # sanity check + assert (decoded == self.eos_index).sum() == 2 * bs + + return decoded, tgt_len, generated_hyps + + +class BeamHypotheses(object): + def __init__(self, n_hyp, max_len, length_penalty, early_stopping): + """ + Initialize n-best list of hypotheses. + """ + self.max_len = max_len - 1 # ignoring + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.n_hyp = n_hyp + self.hyp = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.hyp) + + def add(self, hyp, sum_logprobs): + """ + Add a new hypothesis to the list. 
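+
+        The hypothesis is kept if the list is not yet full, or if its
+        length-penalized score, sum_logprobs / len(hyp) ** length_penalty,
+        beats the current worst score in the list.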
+ """ + score = sum_logprobs / len(hyp) ** self.length_penalty + if len(self) < self.n_hyp or score > self.worst_score: + self.hyp.append((score, hyp)) + if len(self) > self.n_hyp: + sorted_scores = sorted( + [(s, idx) for idx, (s, _) in enumerate(self.hyp)] + ) + del self.hyp[sorted_scores[0][1]] + self.worst_score = sorted_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs): + """ + If there are enough hypotheses and that none of the hypotheses being generated + can become better than the worst one in the heap, + then we are done with this sentence. + """ + if len(self) < self.n_hyp: + return False + elif self.early_stopping: + return True + else: + return ( + self.worst_score + >= best_sum_logprobs / self.max_len ** self.length_penalty + ) + +def top_k_top_p_filtering( + logits: torch.FloatTensor, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +) -> torch.FloatTensor: + """ + Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + top_k (`int`, *optional*, defaults to 0): + If > 0, only keep the top k tokens with highest probability (top-k filtering) + top_p (`float`, *optional*, defaults to 1.0): + If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus + filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimumber of tokens we keep per batch example in the output. + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( + None, logits + ) + + if 0 <= top_p <= 1.0: + logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) + + return logits + +class LogitsWarper: + """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + """Torch method for warping logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class TopKLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + Args: + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") + + self.top_k = top_k + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + +class TopPLogitsWarper(LogitsWarper): + """ + [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + Args: + top_p (`float`): + If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept + for generation. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + top_p = float(top_p) + if top_p < 0 or top_p > 1.0: + raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") + + self.top_p = top_p + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > self.top_p + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + diff --git a/symbolicregression/model/utils_wrapper.py b/symbolicregression/model/utils_wrapper.py new file mode 100644 index 0000000..e7ff7a7 --- /dev/null +++ b/symbolicregression/model/utils_wrapper.py @@ -0,0 +1,198 @@ +from abc import ABC, abstractmethod +import sklearn +from scipy.optimize import minimize +import numpy as np +import time +import torch +from functorch import grad +from functools import partial +import traceback + +class TimedFun: + def __init__(self, fun, verbose=False, stop_after=3): + self.fun_in = fun + self.started = False + self.stop_after = stop_after + self.best_fun_value = np.infty + self.best_x = None + self.loss_history=[] + self.verbose = verbose + + def fun(self, x, *args): + if self.started is False: + self.started = time.time() + elif abs(time.time() - self.started) >= self.stop_after: + 
+
+class Scaler(ABC):
+    """
+    Base class for scalers
+    """
+
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def fit(self, X):
+        pass
+
+    @abstractmethod
+    def fit_transform(self, X):
+        pass
+
+    @abstractmethod
+    def transform(self, X):
+        pass
+
+    @abstractmethod
+    def get_params(self):
+        pass
+
+    def rescale_function(self, env, tree, a, b):
+        """
+        Rewrite each variable x_k in the prefix expression as (a_k * x_k + b_k),
+        undoing the input scaling at the expression level.
+        """
+        prefix = tree.prefix().split(",")
+        idx = 0
+        while idx < len(prefix):
+            if prefix[idx].startswith("x_"):
+                k = int(prefix[idx][-1])
+                if k >= len(a):
+                    idx += 1
+                    continue
+                a_k, b_k = str(a[k]), str(b[k])
+                prefix_to_add = ["add", b_k, "mul", a_k, prefix[idx]]
+                prefix = prefix[:idx] + prefix_to_add + prefix[min(idx + 1, len(prefix)):]
+                idx += len(prefix_to_add)
+            else:
+                idx += 1
+        rescaled_tree = env.word_to_infix(prefix, is_float=False, str_array=False)
+        return rescaled_tree
+
+class StandardScaler(Scaler):
+    def __init__(self):
+        """
+        transformation is:
+        x' = (x - mean)/std
+        """
+        self.scaler = sklearn.preprocessing.StandardScaler()
+
+    def fit(self, X):
+        self.scaler.fit(X)
+
+    def fit_transform(self, X):
+        scaled_X = self.scaler.fit_transform(X)
+        return scaled_X
+
+    def transform(self, X):
+        m, s = self.scaler.mean_, np.sqrt(self.scaler.var_)
+        return (X - m) / s
+
+    def get_params(self):
+        m, s = self.scaler.mean_, np.sqrt(self.scaler.var_)
+        a, b = 1 / s, -m / s
+        return (a, b)
+
+class MinMaxScaler(Scaler):
+    def __init__(self):
+        """
+        transformation is:
+        x' = 2.*(x-xmin)/(xmax-xmin)-1.
+        """
+        self.scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(-1, 1))
+
+    def fit(self, X):
+        self.scaler.fit(X)
+
+    def fit_transform(self, X):
+        scaled_X = self.scaler.fit_transform(X)
+        return scaled_X
+
+    def transform(self, X):
+        val_min, val_max = self.scaler.data_min_, self.scaler.data_max_
+        return 2 * (X - val_min) / (val_max - val_min) - 1.
+
+    def get_params(self):
+        val_min, val_max = self.scaler.data_min_, self.scaler.data_max_
+        a, b = 2. / (val_max - val_min), -1. - 2. * val_min / (val_max - val_min)
+        return (a, b)
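Both scalers expose the fitted scaling as an affine map x' = a*x + b via get_params, which is exactly what rescale_function splices back into the expression tree. A quick consistency check (toy data, using the classes above; note this MinMaxScaler is the wrapper defined here, not sklearn's):

```python
import numpy as np

X = np.array([[1.0], [2.0], [4.0]])
scaler = MinMaxScaler()
scaled = scaler.fit_transform(X)
a, b = scaler.get_params()
assert np.allclose(a * X + b, scaled)  # the affine form reproduces the transform
```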
+ """ + if not isinstance(coeffs, torch.Tensor): + coeffs = torch.tensor(coeffs, dtype=torch.float64, requires_grad=True) + y_tilde = self.func(coeffs) + if y_tilde is None: return None + mse = (self.y -y_tilde).pow(2).mean().div(2) + return mse + + def objective_numpy(coeffs): + """ + Return the objective value as a float (for scipy). + """ + return objective_torch(coeffs).item() + + def gradient_numpy(coeffs): + """ + Compute the gradient of the objective at coeffs. + Returns a numpy array (for scipy) + """ + if not isinstance(coeffs, torch.Tensor): + coeffs = torch.tensor(coeffs, dtype=torch.float64, requires_grad=True) + grad_obj = grad(objective_torch)(coeffs) + return grad_obj.detach().numpy() + + objective_numpy_timed = TimedFun(objective_numpy, stop_after=stop_after) + + try: + minimize( + objective_numpy_timed.fun, + coeffs0, + method="BFGS", + jac=gradient_numpy, + options = {"disp": False} + ) + except ValueError as e: + traceback.format_exc() + best_constants = objective_numpy_timed.best_x + return env.wrap_equation_floats(tree, best_constants) \ No newline at end of file diff --git a/symbolicregression/regressors.py b/symbolicregression/regressors.py new file mode 100644 index 0000000..ab103f5 --- /dev/null +++ b/symbolicregression/regressors.py @@ -0,0 +1,184 @@ +from abc import ABC, abstractmethod +import numpy as np +import scipy +try: import xgboost as xgb +except: print("xgb problems") +from sympy import * + +def order_data(X, y): + idx_sorted = np.squeeze(np.argsort(X[:, :1], axis=0), -1) + X = X[idx_sorted] + y = y[idx_sorted] + return X, y + + +def get_infinite_relative_error(prediction, truth): + abs_relative_error = np.abs((prediction - truth) / (truth + 1e-100)) + abs_relative_error = np.nan_to_num(abs_relative_error, nan=np.infty) + return np.max(abs_relative_error) + + +class Regressor(ABC): + def __init__(self, **args): + pass + + @abstractmethod + def fit(self, X, y): + pass + + @abstractmethod + def predict(self, X): + pass + + +class DeepSymbolicRegressor(Regressor): + def __init__(self): + import dso + + self.model = dso.DeepSymbolicRegressor() + self.n_samples = 2000 + + def fit(self, X, y): + if len(y.shape)>1: y = np.squeeze(y,-1) + assert X.shape[0] == y.shape[0] + self.model.fit(X, y, n_samples=self.n_samples) + + def predict(self, X): + return self.model.predict(X) + +class LagrangeRegressor(Regressor): + def __init__(self): + pass + + def fit(self, X, y): + if len(y.shape)>1: y = np.squeeze(y,-1) + + assert X.shape[0] == y.shape[0] + # X,y = order_data(X,y) + X = np.squeeze(X, -1) + self.model = scipy.interpolate.lagrange(X, y) + + def predict(self, X): + if getattr(self, "model", None) is None: + assert False + X = np.squeeze(X, -1) + return self.model(X) + + +class CubicSplineRegressor(Regressor): + def __init__(self): + pass + + def fit(self, X, y): + if len(y.shape)>1: y = np.squeeze(y,-1) + assert X.shape[0] == y.shape[0] + X, y = order_data(X, y) + + X = np.squeeze(X, -1) + diff = np.diff(X) + duplicates = np.concatenate([[False], diff == 0]) + X = X[~duplicates] + y = y[~duplicates] + self.model = scipy.interpolate.CubicSpline(X, y) + + def predict(self, X): + if getattr(self, "model", None) is None: + assert False + X = np.squeeze(X, -1) + return self.model(X) + + +class gplearnSymbolicRegressor(Regressor): + def __init__(self, function_set, const_range): + import gplearn.genetic + + self.admissible_function_set = { + 'add': lambda x, y : x + y, + 'sub': lambda x, y : x - y, + 'mul': lambda x, y : x*y, + 'div': lambda x, y : x/y, + 'sqrt': 
+class gplearnSymbolicRegressor(Regressor):
+    def __init__(self, function_set, const_range):
+        import gplearn.genetic
+
+        self.admissible_function_set = {
+            'add': lambda x, y: x + y,
+            'sub': lambda x, y: x - y,
+            'mul': lambda x, y: x * y,
+            'div': lambda x, y: x / y,
+            'sqrt': lambda x: x ** 0.5,
+            'log': lambda x: log(x),
+            'abs': lambda x: abs(x),
+            'neg': lambda x: -x,
+            'inv': lambda x: 1 / x,
+            # sympy's Max/Min: Python's builtin max/min cannot compare symbolic arguments
+            'max': lambda x, y: Max(x, y),
+            'min': lambda x, y: Min(x, y),
+            'sin': lambda x: sin(x),
+            'cos': lambda x: cos(x),
+            'pow': lambda x, y: x ** y,
+        }
+
+        self.function_set = function_set.split(",")
+        self.function_set = list(set(self.function_set) & set(self.admissible_function_set.keys()))
+        self.const_range = const_range
+        self.model = gplearn.genetic.SymbolicRegressor(
+            function_set=self.function_set, const_range=self.const_range
+        )
+
+    def get_function(self):
+        return sympify(str(self.model._program), locals=self.admissible_function_set)
+
+    def fit(self, X, y):
+        if len(y.shape) > 1:
+            y = np.squeeze(y, -1)
+        assert X.shape[0] == y.shape[0]
+        self.model.fit(X, y)
+
+    def predict(self, X):
+        return self.model.predict(X)
+
+
+class LinearRegressor(Regressor):
+    def __init__(self):
+        import sklearn.linear_model
+
+        self.model = sklearn.linear_model.LinearRegression()
+
+    def fit(self, X, y):
+        if len(y.shape) > 1:
+            y = np.squeeze(y, -1)
+        assert X.shape[0] == y.shape[0]
+        self.model.fit(X, y)
+
+    def predict(self, X):
+        try:
+            return self.model.predict(X)
+        except Exception as e:
+            print(e, X)
+            return None
+
+class MLPRegressor(Regressor):
+    def __init__(self):
+        import sklearn.neural_network
+
+        self.model = sklearn.neural_network.MLPRegressor()
+
+    def fit(self, X, y):
+        if len(y.shape) > 1:
+            y = np.squeeze(y, -1)
+        assert X.shape[0] == y.shape[0]
+        self.model.fit(X, y)
+
+    def predict(self, X):
+        try:
+            return self.model.predict(X)
+        except Exception as e:
+            print(e, X)
+            return None
+
+class XGBoostRegressor(Regressor):
+    def __init__(self, **params):
+        self.params = params
+
+    def fit(self, X, y):
+        if len(y.shape) > 1:
+            y = np.squeeze(y, -1)
+        assert X.shape[0] == y.shape[0]
+        self.model = xgb.XGBRegressor(**self.params)
+        try:
+            self.model.fit(X, y)
+        except Exception:
+            self.model = None
+
+    def predict(self, X):
+        if getattr(self, "model", None) is None:
+            return None
+        return self.model.predict(X)
diff --git a/symbolicregression/utils.py b/symbolicregression/utils.py
new file mode 100644
index 0000000..715a5df
--- /dev/null
+++ b/symbolicregression/utils.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2020-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import os
+import re
+import sys
+import math
+import time
+import pickle
+import random
+import getpass
+import argparse
+import subprocess
+
+import errno
+import signal
+from functools import wraps, partial
+
+FALSY_STRINGS = {"off", "false", "0"}
+TRUTHY_STRINGS = {"on", "true", "1"}
+
+CUDA = True
+
+def bool_flag(s):
+    """
+    Parse boolean arguments from the command line.
+    """
+    if s.lower() in FALSY_STRINGS:
+        return False
+    elif s.lower() in TRUTHY_STRINGS:
+        return True
+    else:
+        raise argparse.ArgumentTypeError("Invalid value for a boolean flag!")
+
+def to_cuda(*args, use_cpu=False):
+    """
+    Move tensors to CUDA.
+    """
+    if not CUDA or use_cpu:
+        return args
+    return [None if x is None else x.cuda() for x in args]
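bool_flag above is meant to be used as an argparse `type`, so command-line flags accept on/off-style strings. A short usage sketch (the flag name is invented for illustration):

```python
import argparse

parser = argparse.ArgumentParser()
# hypothetical flag, shown only to illustrate bool_flag as an argparse type
parser.add_argument("--use_cpu", type=bool_flag, default=False)
args = parser.parse_args(["--use_cpu", "on"])
print(args.use_cpu)  # True
```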
+ """ + if not CUDA or use_cpu: + return args + return [None if x is None else x.cuda() for x in args] + + +class MyTimeoutError(BaseException): + pass + + +def timeout(seconds=10, error_message=os.strerror(errno.ETIME)): + def decorator(func): + def _handle_timeout(repeat_id, signum, frame): + # logger.warning(f"Catched the signal ({repeat_id}) Setting signal handler {repeat_id + 1}") + signal.signal(signal.SIGALRM, partial(_handle_timeout, repeat_id + 1)) + signal.alarm(seconds) + raise MyTimeoutError(error_message) + + def wrapper(*args, **kwargs): + old_signal = signal.signal(signal.SIGALRM, partial(_handle_timeout, 0)) + old_time_left = signal.alarm(seconds) + assert type(old_time_left) is int and old_time_left >= 0 + if 0 < old_time_left < seconds: # do not exceed previous timer + signal.alarm(old_time_left) + start_time = time.time() + try: + result = func(*args, **kwargs) + finally: + if old_time_left == 0: + signal.alarm(0) + else: + sub = time.time() - start_time + signal.signal(signal.SIGALRM, old_signal) + signal.alarm(max(0, math.ceil(old_time_left - sub))) + return result + + return wraps(func)(wrapper) + + return decorator \ No newline at end of file