Skip to content

Commit

Permalink
first push
Browse files Browse the repository at this point in the history
  • Loading branch information
Pierre-Alexandre Kamienny committed Apr 13, 2022
0 parents commit a420897
Show file tree
Hide file tree
Showing 18 changed files with 4,601 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Deep Symbolic Regression
22 changes: 22 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from setuptools import setup, find_packages

# Utility function to read the README file.
# Used for the long_description. It's nice, because now 1) we have a top level
# README file and 2) it's easier to type in the README file than to put a raw
# string in below ...
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()

setup(
name = "symbolicregression",
version = "0.0.1",
author = "Pierre-Alexandre Kamienny",
author_email = "[email protected]",
description = ("Performing Symbolic Regression with Transformers"),
license = "BSD",
keywords = "symbolic regression, transformers",
url = "",
packages=find_packages(),
long_description=read('README.md'),
)
1 change: 1 addition & 0 deletions symbolicregression/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import model
34 changes: 34 additions & 0 deletions symbolicregression/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2020-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from logging import getLogger

# from .generators import operators_conv, Node
from .environment import FunctionEnvironment

logger = getLogger()


ENVS = {
"functions": FunctionEnvironment,
}


def build_env(params):
"""
Build environment.
"""
env = ENVS[params.env_name](params)

# tasks
tasks = [x for x in params.tasks.split(",") if len(x) > 0]
assert len(tasks) == len(set(tasks)) > 0
assert all(task in env.TRAINING_TASKS for task in tasks)
params.tasks = tasks
logger.info(f'Training tasks: {", ".join(tasks)}')

return env
230 changes: 230 additions & 0 deletions symbolicregression/envs/encoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
from abc import ABC, abstractmethod
import numpy as np
import math
from .generators import Node, NodeList
from .utils import *


class Encoder(ABC):
"""
Base class for encoders, encodes and decodes matrices
abstract methods for encoding/decoding numbers
"""

def __init__(self, params):
pass

@abstractmethod
def encode(self, val):
pass

@abstractmethod
def decode(self, lst):
pass

class GeneralEncoder:
def __init__(self, params, symbols, all_operators):
self.float_encoder = FloatSequences(params)
self.equation_encoder = Equation(params, symbols, self.float_encoder, all_operators)

class FloatSequences(Encoder):
def __init__(self, params):
super().__init__(params)
self.float_precision = params.float_precision
self.mantissa_len = params.mantissa_len
self.max_exponent = params.max_exponent
self.base = (self.float_precision + 1) // self.mantissa_len
self.max_token = 10 ** self.base
self.symbols = ["+", "-"]
self.symbols.extend(
["N" + f"%0{self.base}d" % i for i in range(self.max_token)]
)
self.symbols.extend(
["E" + str(i) for i in range(-self.max_exponent, self.max_exponent + 1)]
)

def encode(self, values):
"""
Write a float number
"""
precision = self.float_precision

if len(values.shape) == 1:
seq = []
value = values
for val in value:
assert val not in [-np.inf, np.inf]
sign = "+" if val >= 0 else "-"
m, e = (f"%.{precision}e" % val).split("e")
i, f = m.lstrip("-").split(".")
i = i + f
tokens = chunks(i, self.base)
expon = int(e) - precision
if expon < -self.max_exponent:
tokens = ["0" * self.base] * self.mantissa_len
expon = int(0)
seq.extend([sign, *["N" + token for token in tokens], "E" + str(expon)])
return seq
else:
seqs = [self.encode(values[0])]
N = values.shape[0]
for n in range(1, N):
seqs += [self.encode(values[n])]
return seqs

def decode(self, lst):
"""
Parse a list that starts with a float.
Return the float value, and the position it ends in the list.
"""
if len(lst) == 0:
return None
seq = []
for val in chunks(lst, 2 + self.mantissa_len):
for x in val:
if x[0] not in ["-", "+", "E", "N"]:
return np.nan
try:
sign = 1 if val[0] == "+" else -1
mant = ""
for x in val[1:-1]:
mant += x[1:]
mant = int(mant)
exp = int(val[-1][1:])
value = sign * mant * (10 ** exp)
value = float(value)
except Exception:
value = np.nan
seq.append(value)
return seq


class Equation(Encoder):
def __init__(self, params, symbols, float_encoder, all_operators):
super().__init__(params)
self.params = params
self.max_int = self.params.max_int
self.symbols = symbols
if params.extra_unary_operators != "":
self.extra_unary_operators = self.params.extra_unary_operators.split(",")
else:
self.extra_unary_operators = []
if params.extra_binary_operators != "":
self.extra_binary_operators = self.params.extra_binary_operators.split(",")
else:
self.extra_binary_operators = []
self.float_encoder = float_encoder
self.all_operators=all_operators

def encode(self, tree):
res = []
for elem in tree.prefix().split(","):
try:
val = float(elem)
if elem.lstrip('-').isdigit():
res.extend(self.write_int(int(elem)))
else:
res.extend(self.float_encoder.encode(np.array([val])))
except ValueError:
res.append(elem)
return res

def _decode(self, lst):
if len(lst) == 0:
return None, 0
# elif (lst[0] not in self.symbols) and (not lst[0].lstrip("-").replace(".","").replace("e+", "").replace("e-","").isdigit()):
# return None, 0
elif "OOD" in lst[0]:
return None, 0
elif lst[0] in self.all_operators.keys():
res = Node(lst[0], self.params)
arity = self.all_operators[lst[0]]
pos = 1
for i in range(arity):
child, length = self._decode(lst[pos:])
if child is None:
return None, pos
res.push_child(child)
pos += length
return res, pos
elif lst[0].startswith("INT"):
val, length = self.parse_int(lst)
return Node(str(val), self.params), length
elif lst[0]=="+" or lst[0]=="-":
try:
val = self.float_encoder.decode(lst[:3])[0]
except Exception as e:
#print(e, "error in encoding, lst: {}".format(lst))
return None, 0
return Node(str(val), self.params), 3
elif lst[0].startswith("CONSTANT") or lst[0]=="y": ##added this manually CAREFUL!!
return Node(lst[0], self.params), 1
elif lst[0] in self.symbols:
return Node(lst[0], self.params), 1
else:
try:
float(lst[0]) #if number, return leaf
return Node(lst[0], self.params), 1
except:
return None, 0

def split_at_value(self, lst, value):
indices = [i for i, x in enumerate(lst) if x == value]
res = []
for start, end in zip(
[0, *[i + 1 for i in indices]], [*[i - 1 for i in indices], len(lst)]
):
res.append(lst[start : end + 1])
return res

def decode(self, lst):
trees = []
lists = self.split_at_value(lst, "|")
for lst in lists:
tree = self._decode(lst)[0]
if tree is None:
return None
trees.append(tree)
tree = NodeList(trees)
return tree

def parse_int(self, lst):
"""
Parse a list that starts with an integer.
Return the integer value, and the position it ends in the list.
"""
base = self.max_int
val = 0
i = 0
for x in lst[1:]:
if not (x.rstrip("-").isdigit()):
break
val = val * base + int(x)
i += 1
if base > 0 and lst[0] == "INT-":
val = -val
return val, i + 1

def write_int(self, val):
"""
Convert a decimal integer to a representation in the given base.
"""
if not self.params.use_sympy:
return [str(val)]

base = self.max_int
res = []
max_digit = abs(base)
neg = val < 0
val = -val if neg else val
while True:
rem = val % base
val = val // base
if rem < 0 or rem > max_digit:
rem -= base
val += 1
res.append(str(rem))
if val == 0:
break
res.append("INT-" if neg else "INT+")
return res[::-1]
Loading

0 comments on commit a420897

Please sign in to comment.