Skip to content

Commit

Permalink
update polytype to 7.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lorepozo committed Dec 15, 2023
1 parent ba2d184 commit 5ee3542
Show file tree
Hide file tree
Showing 16 changed files with 177 additions and 180 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ rustdoc-args = ["--html-in-header", "rustdoc-include-katex-header.html"]
crossbeam-channel = "0.5.8"
itertools = "0.11.0"
once_cell = "1.18.0"
polytype = "6.1"
polytype = "7.0.1"
rand = { version = "0.8.5", features = ["small_rng"] }
rayon = "1.8.0"
serde = { version = "1.0", features = ["derive"] }
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ _strings_.
- [x] `impl GP for pcfg::Grammar` is not yet complete.
- [ ] Eta-long sidestepping (so `f` gets enumerated instead of `(λ (f $0))`)
- [ ] Consolidate lazy/non-lazy evaluation (for ergonomics).
- [ ] Permit non-`&'static str`-named `Type`/`TypeSchema`.
- [ ] Permit non-`&'static str`-named `Type`/`TypeScheme`.
- [ ] Ability to include recursive primitives in `lambda` representation.
- [ ] Faster lambda calculus evaluation (less cloning; bubble up whether
beta reduction happened rather than ultimate equality comparison).
Expand Down
8 changes: 4 additions & 4 deletions examples/json_compressor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use polytype::TypeSchema;
use polytype::TypeScheme;
use programinduction::{lambda, noop_task, ECFrontier};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -58,7 +58,7 @@ struct Solution {
struct CompressionInput {
dsl: lambda::Language,
params: lambda::CompressionParams,
task_types: Vec<TypeSchema>,
task_types: Vec<TypeScheme>,
frontiers: Vec<ECFrontier<lambda::Expression>>,
}
impl From<ExternalCompressionInput> for CompressionInput {
Expand All @@ -69,7 +69,7 @@ impl From<ExternalCompressionInput> for CompressionInput {
.map(|p| {
(
p.name,
TypeSchema::parse(&p.tp).expect("invalid primitive type"),
p.tp.parse::<TypeScheme>().expect("invalid primitive type"),
p.logp,
)
})
Expand Down Expand Up @@ -98,7 +98,7 @@ impl From<ExternalCompressionInput> for CompressionInput {
.frontiers
.into_par_iter()
.map(|f| {
let tp = TypeSchema::parse(&f.task_tp).expect("invalid task type");
let tp = f.task_tp.parse::<TypeScheme>().expect("invalid task type");
let sols = f
.solutions
.into_iter()
Expand Down
8 changes: 4 additions & 4 deletions src/domains/circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
//! ```
use itertools::Itertools;
use polytype::{ptp, tp, Type, TypeSchema};
use polytype::{ptp, tp, Type, TypeScheme};
use rand::{
distributions::{Distribution, WeightedIndex},
Rng,
Expand Down Expand Up @@ -191,11 +191,11 @@ pub fn make_tasks_advanced<R: Rng>(
struct CircuitTask {
n_inputs: usize,
expected_outputs: Vec<bool>,
tp: TypeSchema,
tp: TypeScheme,
}
impl CircuitTask {
fn new(n_inputs: usize, expected_outputs: Vec<bool>) -> Self {
let tp = TypeSchema::Monotype(Type::from(vec![tp!(bool); n_inputs + 1]));
let tp = TypeScheme::Monotype(Type::from(vec![tp!(bool); n_inputs + 1]));
CircuitTask {
n_inputs,
expected_outputs,
Expand Down Expand Up @@ -226,7 +226,7 @@ impl Task<[bool]> for CircuitTask {
f64::NEG_INFINITY
}
}
fn tp(&self) -> &TypeSchema {
fn tp(&self) -> &TypeScheme {
&self.tp
}
fn observation(&self) -> &[bool] {
Expand Down
4 changes: 2 additions & 2 deletions src/domains/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ static OPERATIONS: Lazy<HashMap<&'static str, Op>> = Lazy::new(|| {

mod gen {
use itertools::Itertools;
use polytype::{ptp, tp, TypeSchema};
use polytype::{ptp, tp, TypeScheme};
use rand::distributions::{Distribution, Uniform};
use rand::{self, Rng};

Expand Down Expand Up @@ -424,7 +424,7 @@ mod gen {
pub fn make_examples<R: Rng>(
rng: &mut R,
n_examples: usize,
) -> Vec<(&'static str, TypeSchema, Vec<(Vec<Space>, Space)>)> {
) -> Vec<(&'static str, TypeScheme, Vec<(Vec<Space>, Space)>)> {
let mut tasks = Vec::new();

macro_rules! t {
Expand Down
8 changes: 4 additions & 4 deletions src/ec.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Representations capable of Exploration-Compression.
use crossbeam_channel::bounded;
use polytype::TypeSchema;
use polytype::TypeScheme;
use rayon::prelude::*;
use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
Expand All @@ -19,7 +19,7 @@ pub struct ECParams {
/// The maximum frontier size; the number of task solutions to be hit before enumeration is
/// stopped for a particular task.
pub frontier_limit: usize,
/// A timeout before enumeration is stopped, run independently per distinct `TypeSchema` being
/// A timeout before enumeration is stopped, run independently per distinct `TypeScheme` being
/// enumerated. If this is reached, there may be fewer than `frontier_limit` many solutions.
pub search_limit_timeout: Option<Duration>,
/// An approximate limit on enumerated description length. If this is reached, there may be
Expand Down Expand Up @@ -95,7 +95,7 @@ pub trait EC<Observation: ?Sized>: Sync + Sized {
/// If it responds with true, enumeration must stop (i.e. this method call should terminate).
///
/// [`Expression`]: #associatedtype.Expression
fn enumerate<F>(&self, tp: TypeSchema, termination_condition: F)
fn enumerate<F>(&self, tp: TypeScheme, termination_condition: F)
where
F: Fn(Self::Expression, f64) -> bool + Sync;
/// Update the representation based on findings of expressions that solve [`Task`]s.
Expand Down Expand Up @@ -296,7 +296,7 @@ pub trait EC<Observation: ?Sized>: Sync + Sized {
fn enumerate_solutions<Observation, L, T>(
repr: &L,
params: &ECParams,
tp: TypeSchema,
tp: TypeScheme,
tasks: Vec<(usize, &T)>,
) -> Vec<(usize, ECFrontier<L::Expression>)>
where
Expand Down
4 changes: 2 additions & 2 deletions src/gp.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Representations capable of Genetic Programming.
use itertools::Itertools;
use polytype::TypeSchema;
use polytype::TypeScheme;
use rand::{distributions::Distribution, distributions::WeightedIndex, seq::SliceRandom, Rng};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
Expand Down Expand Up @@ -225,7 +225,7 @@ pub trait GP<Observation: ?Sized> {
params: &Self::Params,
rng: &mut R,
pop_size: usize,
tp: &TypeSchema,
tp: &TypeScheme,
) -> Vec<Self::Expression>;

/// Mutate a single program, potentially producing multiple offspring
Expand Down
12 changes: 6 additions & 6 deletions src/lambda/compression.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crossbeam_channel::bounded;
use itertools::Itertools;
use polytype::{Context, Type, TypeSchema};
use polytype::{Context, Type, TypeScheme};
use rayon::join;
use rayon::prelude::*;
use std::borrow::Cow;
Expand Down Expand Up @@ -127,15 +127,15 @@ where
P: Fn(
&I,
&Language,
&[(TypeSchema, Vec<(Expression, f64, f64)>)],
&[(TypeScheme, Vec<(Expression, f64, f64)>)],
&CompressionParams,
&mut Vec<T::Expression>,
) + Sync,
D: Fn(
&I,
&T::Expression,
&mut Language,
&[(TypeSchema, Vec<(Expression, f64, f64)>)],
&[(TypeScheme, Vec<(Expression, f64, f64)>)],
&CompressionParams,
) -> Option<f64>
+ Sync,
Expand All @@ -145,7 +145,7 @@ where
T::Expression,
Expression,
&Language,
&mut Vec<(TypeSchema, Vec<(Expression, f64, f64)>)>,
&mut Vec<(TypeScheme, Vec<(Expression, f64, f64)>)>,
&CompressionParams,
),
{
Expand Down Expand Up @@ -253,7 +253,7 @@ where
}

/// A convenient frontier representation.
pub type RescoredFrontier = (TypeSchema, Vec<(Expression, f64, f64)>);
pub type RescoredFrontier = (TypeScheme, Vec<(Expression, f64, f64)>);

pub fn joint_mdl(dsl: &Language, frontiers: &[RescoredFrontier]) -> f64 {
frontiers
Expand Down Expand Up @@ -408,7 +408,7 @@ impl Language {

/// This is similar to `enumerator::likelihood` but it does a lot more work to determine
/// _outside_ counts.
fn uses(&self, request: &TypeSchema, expr: &Expression) -> (f64, Uses) {
fn uses(&self, request: &TypeScheme, expr: &Expression) -> (f64, Uses) {
let mut ctx = Context::default();
let tp = request.clone().instantiate_owned(&mut ctx);
let env = Rc::new(LinkedList::default());
Expand Down
6 changes: 3 additions & 3 deletions src/lambda/enumerator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use polytype::{Context, Type, TypeSchema};
use polytype::{Context, Type, TypeScheme};
use std::collections::VecDeque;
use std::rc::Rc;

Expand All @@ -23,7 +23,7 @@ fn budget_interval(n: u32) -> (f64, f64) {
}
}

pub fn run<F>(dsl: &Language, request: TypeSchema, termination_condition: F)
pub fn run<F>(dsl: &Language, request: TypeScheme, termination_condition: F)
where
F: Fn(Expression, f64) -> bool + Sync,
{
Expand Down Expand Up @@ -60,7 +60,7 @@ where
}
}

pub fn likelihood(dsl: &Language, request: &TypeSchema, expr: &Expression) -> f64 {
pub fn likelihood(dsl: &Language, request: &TypeScheme, expr: &Expression) -> f64 {
let mut ctx = Context::default();
let env = Rc::new(LinkedList::default());
let t = request.clone().instantiate_owned(&mut ctx);
Expand Down
16 changes: 8 additions & 8 deletions src/lambda/eval.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Evaluation happens by calling primitives provided by an evaluator.
use polytype::TypeSchema;
use polytype::TypeScheme;
use std::collections::VecDeque;
use std::sync::Arc;

Expand Down Expand Up @@ -488,7 +488,7 @@ use self::ReducedExpression::*;
#[derive(Clone, PartialEq)]
pub enum ReducedExpression<V> {
Value(V),
Primitive(String, TypeSchema),
Primitive(String, TypeScheme),
Application(Vec<ReducedExpression<V>>),
/// store depth (never zero) for nested abstractions.
Abstraction(usize, Box<ReducedExpression<V>>),
Expand Down Expand Up @@ -856,11 +856,11 @@ where
}
}

fn arity(mut tp: &TypeSchema) -> usize {
fn arity(mut tp: &TypeScheme) -> usize {
let mut tp = loop {
match *tp {
TypeSchema::Monotype(ref t) => break t,
TypeSchema::Polytype { ref body, .. } => tp = body,
TypeScheme::Monotype(ref t) => break t,
TypeScheme::Polytype { ref body, .. } => tp = body,
}
};
let mut count = 0;
Expand All @@ -871,11 +871,11 @@ fn arity(mut tp: &TypeSchema) -> usize {
count
}

fn is_arrow(mut tp: &TypeSchema) -> bool {
fn is_arrow(mut tp: &TypeScheme) -> bool {
loop {
match *tp {
TypeSchema::Monotype(ref t) => break t.as_arrow().is_some(),
TypeSchema::Polytype { ref body, .. } => tp = body,
TypeScheme::Monotype(ref t) => break t.as_arrow().is_some(),
TypeScheme::Polytype { ref body, .. } => tp = body,
}
}
}
28 changes: 14 additions & 14 deletions src/lambda/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub use self::eval::{
pub use self::parser::ParseError;

use crossbeam_channel::bounded;
use polytype::{Context, Type, TypeSchema, UnificationError};
use polytype::{Context, Type, TypeScheme, UnificationError};
use rayon::spawn;
use std::collections::{HashMap, VecDeque};
use std::error::Error;
Expand All @@ -59,8 +59,8 @@ const FREE_VAR_COST: f64 = 0.01;
/// polymorphically-typed lambda calculus with corresponding production log-probabilities.
#[derive(Debug, Clone)]
pub struct Language {
pub primitives: Vec<(String, TypeSchema, f64)>,
pub invented: Vec<(Expression, TypeSchema, f64)>,
pub primitives: Vec<(String, TypeScheme, f64)>,
pub invented: Vec<(Expression, TypeScheme, f64)>,
pub variable_logprob: f64,
/// Symmetry breaking prevents certain productions from being made. Specifically, an item
/// `(f, i, a)` means that enumeration will not yield an application of `f` where the `i`th
Expand All @@ -74,7 +74,7 @@ pub struct Language {
impl Language {
/// A uniform distribution over primitives and invented expressions, as well as the abstraction
/// operation.
pub fn uniform(primitives: Vec<(&str, TypeSchema)>) -> Self {
pub fn uniform(primitives: Vec<(&str, TypeScheme)>) -> Self {
let primitives = primitives
.into_iter()
.map(|(s, t)| (String::from(s), t, 0f64))
Expand Down Expand Up @@ -115,7 +115,7 @@ impl Language {
/// ```
///
/// [`Expression`]: enum.Expression.html
pub fn infer(&self, expr: &Expression) -> Result<TypeSchema, InferenceError> {
pub fn infer(&self, expr: &Expression) -> Result<TypeScheme, InferenceError> {
let mut ctx = Context::default();
let env = VecDeque::new();
let mut indices = HashMap::new();
Expand Down Expand Up @@ -161,7 +161,7 @@ impl Language {
/// ```
///
/// [`add_symmetry_violation`]: #method.add_symmetry_violation
pub fn enumerate(&self, tp: TypeSchema) -> Box<dyn Iterator<Item = (Expression, f64)>> {
pub fn enumerate(&self, tp: TypeScheme) -> Box<dyn Iterator<Item = (Expression, f64)>> {
let (tx, rx) = bounded(1);
let dsl = self.clone();
spawn(move || {
Expand Down Expand Up @@ -328,7 +328,7 @@ impl Language {
/// let expr = dsl.parse("(λ (λ (+ (+ $0 1) $1)))").unwrap();
/// assert_eq!(dsl.likelihood(&req, &expr), -8.317766166719343);
/// ```
pub fn likelihood(&self, request: &TypeSchema, expr: &Expression) -> f64 {
pub fn likelihood(&self, request: &TypeScheme, expr: &Expression) -> f64 {
enumerator::likelihood(self, request, expr)
}

Expand Down Expand Up @@ -583,7 +583,7 @@ impl Language {
impl<Observation: ?Sized> EC<Observation> for Language {
type Expression = Expression;
type Params = CompressionParams;
fn enumerate<F>(&self, tp: TypeSchema, termination_condition: F)
fn enumerate<F>(&self, tp: TypeScheme, termination_condition: F)
where
F: Fn(Expression, f64) -> bool + Sync,
{
Expand Down Expand Up @@ -784,7 +784,7 @@ impl Expression {
*self = new_self;
true
}
fn strip_invented(&self, invented: &[(Expression, TypeSchema, f64)]) -> Expression {
fn strip_invented(&self, invented: &[(Expression, TypeScheme, f64)]) -> Expression {
match *self {
Expression::Application(ref f, ref x) => Expression::Application(
Box::new(f.strip_invented(invented)),
Expand Down Expand Up @@ -926,7 +926,7 @@ impl Expression {
/// ```
pub fn task_by_evaluation<E, V>(
evaluator: E,
tp: TypeSchema,
tp: TypeScheme,
examples: impl AsRef<[(Vec<V>, V)]> + Sync,
) -> impl Task<[(Vec<V>, V)], Representation = Language, Expression = Expression>
where
Expand All @@ -946,7 +946,7 @@ where
/// [`task_by_evaluation`]: fn.task_by_evaluation.html
pub fn task_by_lazy_evaluation<E, V>(
evaluator: E,
tp: TypeSchema,
tp: TypeScheme,
examples: impl AsRef<[(Vec<V>, V)]> + Sync,
) -> impl Task<[(Vec<V>, V)], Representation = Language, Expression = Expression>
where
Expand All @@ -962,7 +962,7 @@ where

struct LambdaTask<const LAZY: bool, E, O: Sync> {
evaluator: Arc<E>,
tp: TypeSchema,
tp: TypeScheme,
examples: O,
}
impl<
Expand All @@ -989,7 +989,7 @@ impl<
f64::NEG_INFINITY
}
}
fn tp(&self) -> &TypeSchema {
fn tp(&self) -> &TypeScheme {
&self.tp
}
fn observation(&self) -> &[(Vec<V>, V)] {
Expand Down Expand Up @@ -1020,7 +1020,7 @@ impl<
f64::NEG_INFINITY
}
}
fn tp(&self) -> &TypeSchema {
fn tp(&self) -> &TypeScheme {
&self.tp
}
fn observation(&self) -> &[(Vec<V>, V)] {
Expand Down
Loading

0 comments on commit 5ee3542

Please sign in to comment.