diff --git a/Cargo.toml b/Cargo.toml
index bf35899..65c6adb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,7 +19,7 @@ rustdoc-args = ["--html-in-header", "rustdoc-include-katex-header.html"]
 crossbeam-channel = "0.5.8"
 itertools = "0.11.0"
 once_cell = "1.18.0"
-polytype = "6.1"
+polytype = "7.0.1"
 rand = { version = "0.8.5", features = ["small_rng"] }
 rayon = "1.8.0"
 serde = { version = "1.0", features = ["derive"] }
diff --git a/README.md b/README.md
index e1b9179..56c29d1 100644
--- a/README.md
+++ b/README.md
@@ -123,7 +123,7 @@ _strings_.
 - [x] `impl GP for pcfg::Grammar` is not yet complete.
 - [ ] Eta-long sidestepping (so `f` gets enumerated instead of `(λ (f $0))`)
 - [ ] Consolidate lazy/non-lazy evaluation (for ergonomics).
-- [ ] Permit non-`&'static str`-named `Type`/`TypeSchema`.
+- [ ] Permit non-`&'static str`-named `Type`/`TypeScheme`.
 - [ ] Ability to include recursive primitives in `lambda` representation.
 - [ ] Faster lambda calculus evaluation (less cloning; bubble up whether beta reduction happened rather than ultimate equality comparison).
diff --git a/examples/json_compressor.rs b/examples/json_compressor.rs
index 57d840c..90d2dca 100644
--- a/examples/json_compressor.rs
+++ b/examples/json_compressor.rs
@@ -1,4 +1,4 @@
-use polytype::TypeSchema;
+use polytype::TypeScheme;
 use programinduction::{lambda, noop_task, ECFrontier};
 use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
@@ -58,7 +58,7 @@ struct Solution {
 struct CompressionInput {
     dsl: lambda::Language,
     params: lambda::CompressionParams,
-    task_types: Vec<TypeSchema>,
+    task_types: Vec<TypeScheme>,
     frontiers: Vec<ECFrontier<lambda::Language>>,
 }
 impl From<ExternalCompressionInput> for CompressionInput {
@@ -69,7 +69,7 @@ impl From<ExternalCompressionInput> for CompressionInput {
             .map(|p| {
                 (
                     p.name,
-                    TypeSchema::parse(&p.tp).expect("invalid primitive type"),
+                    p.tp.parse::<TypeScheme>().expect("invalid primitive type"),
                     p.logp,
                 )
             })
@@ -98,7 +98,7 @@ impl From<ExternalCompressionInput> for CompressionInput {
             .frontiers
             .into_par_iter()
             .map(|f| {
-                let tp = TypeSchema::parse(&f.task_tp).expect("invalid task type");
+                let tp = f.task_tp.parse::<TypeScheme>().expect("invalid task type");
                 let sols = f
                     .solutions
                     .into_iter()
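Note on the hunks above: polytype 7 renames `TypeSchema` to `TypeScheme` and drops the inherent `TypeSchema::parse` constructor in favor of a `FromStr` impl, which is why the example now calls `str::parse::<TypeScheme>()`. A minimal sketch of the new call (the type string is illustrative):

```rust
use polytype::TypeScheme;

fn main() {
    // polytype 7.x parses via `FromStr`: `"...".parse()` replaces the old
    // `TypeSchema::parse("...")`.
    let tp: TypeScheme = "t0. t0 -> t0".parse().expect("invalid type");
    println!("{}", tp);
}
```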
diff --git a/src/domains/circuits.rs b/src/domains/circuits.rs
index fea6026..d11db72 100644
--- a/src/domains/circuits.rs
+++ b/src/domains/circuits.rs
@@ -24,7 +24,7 @@
 //! ```
 
 use itertools::Itertools;
-use polytype::{ptp, tp, Type, TypeSchema};
+use polytype::{ptp, tp, Type, TypeScheme};
 use rand::{
     distributions::{Distribution, WeightedIndex},
     Rng,
@@ -191,11 +191,11 @@ pub fn make_tasks_advanced(
 struct CircuitTask {
     n_inputs: usize,
     expected_outputs: Vec<bool>,
-    tp: TypeSchema,
+    tp: TypeScheme,
 }
 impl CircuitTask {
     fn new(n_inputs: usize, expected_outputs: Vec<bool>) -> Self {
-        let tp = TypeSchema::Monotype(Type::from(vec![tp!(bool); n_inputs + 1]));
+        let tp = TypeScheme::Monotype(Type::from(vec![tp!(bool); n_inputs + 1]));
         CircuitTask {
             n_inputs,
             expected_outputs,
@@ -226,7 +226,7 @@ impl Task<[bool]> for CircuitTask {
             f64::NEG_INFINITY
         }
     }
-    fn tp(&self) -> &TypeSchema {
+    fn tp(&self) -> &TypeScheme {
         &self.tp
     }
     fn observation(&self) -> &[bool] {
diff --git a/src/domains/strings.rs b/src/domains/strings.rs
index 4c82900..50a1da9 100644
--- a/src/domains/strings.rs
+++ b/src/domains/strings.rs
@@ -370,7 +370,7 @@ static OPERATIONS: Lazy> = Lazy::new(|| {
 
 mod gen {
     use itertools::Itertools;
-    use polytype::{ptp, tp, TypeSchema};
+    use polytype::{ptp, tp, TypeScheme};
     use rand::distributions::{Distribution, Uniform};
     use rand::{self, Rng};
@@ -424,7 +424,7 @@ mod gen {
     pub fn make_examples<R: Rng>(
         rng: &mut R,
         n_examples: usize,
-    ) -> Vec<(&'static str, TypeSchema, Vec<(Vec<Space>, Space)>)> {
+    ) -> Vec<(&'static str, TypeScheme, Vec<(Vec<Space>, Space)>)> {
         let mut tasks = Vec::new();
 
         macro_rules! t {
diff --git a/src/ec.rs b/src/ec.rs
index eaf9372..e17a0ee 100644
--- a/src/ec.rs
+++ b/src/ec.rs
@@ -1,7 +1,7 @@
 //! Representations capable of Exploration-Compression.
 
 use crossbeam_channel::bounded;
-use polytype::TypeSchema;
+use polytype::TypeScheme;
 use rayon::prelude::*;
 use std::collections::HashMap;
 use std::ops::{Deref, DerefMut};
@@ -19,7 +19,7 @@ pub struct ECParams {
     /// The maximum frontier size; the number of task solutions to be hit before enumeration is
     /// stopped for a particular task.
     pub frontier_limit: usize,
-    /// A timeout before enumeration is stopped, run independently per distinct `TypeSchema` being
+    /// A timeout before enumeration is stopped, run independently per distinct `TypeScheme` being
     /// enumerated. If this is reached, there may be fewer than `frontier_limit` many solutions.
     pub search_limit_timeout: Option<Duration>,
     /// An approximate limit on enumerated description length. If this is reached, there may be
@@ -95,7 +95,7 @@ pub trait EC: Sync + Sized {
     /// If it responds with true, enumeration must stop (i.e. this method call should terminate).
     ///
     /// [`Expression`]: #associatedtype.Expression
-    fn enumerate<F>(&self, tp: TypeSchema, termination_condition: F)
+    fn enumerate<F>(&self, tp: TypeScheme, termination_condition: F)
     where
         F: Fn(Self::Expression, f64) -> bool + Sync;
     /// Update the representation based on findings of expressions that solve [`Task`]s.
@@ -296,7 +296,7 @@
 fn enumerate_solutions<L, T>(
     repr: &L,
     params: &ECParams,
-    tp: TypeSchema,
+    tp: TypeScheme,
     tasks: Vec<(usize, &T)>,
 ) -> Vec<(usize, ECFrontier<L>)>
 where
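As a sketch of driving the updated `EC::enumerate` signature above (this relies on `Language`'s `EC` impl shown later in this diff; the `nand` primitive and the cutoff are illustrative):

```rust
use polytype::{ptp, tp};
use programinduction::{lambda::Language, EC};

fn main() {
    let dsl = Language::uniform(vec![(
        "nand",
        ptp!(@arrow[tp!(bool), tp!(bool), tp!(bool)]),
    )]);
    // The termination condition returns true to tell the enumerator to stop.
    EC::enumerate(&dsl, ptp!(@arrow[tp!(bool), tp!(bool)]), |expr, logprior| {
        println!("{:?} at {}", expr, logprior);
        logprior < -5.0
    });
}
```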
diff --git a/src/gp.rs b/src/gp.rs
index 877d402..02c439e 100644
--- a/src/gp.rs
+++ b/src/gp.rs
@@ -1,7 +1,7 @@
 //! Representations capable of Genetic Programming.
 
 use itertools::Itertools;
-use polytype::TypeSchema;
+use polytype::TypeScheme;
 use rand::{distributions::Distribution, distributions::WeightedIndex, seq::SliceRandom, Rng};
 use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
@@ -225,7 +225,7 @@ pub trait GP {
         params: &Self::Params,
         rng: &mut R,
         pop_size: usize,
-        tp: &TypeSchema,
+        tp: &TypeScheme,
     ) -> Vec<Self::Expression>;
 
     /// Mutate a single program, potentially producing multiple offspring
diff --git a/src/lambda/compression.rs b/src/lambda/compression.rs
index 4f5e7d2..44f88ae 100644
--- a/src/lambda/compression.rs
+++ b/src/lambda/compression.rs
@@ -1,6 +1,6 @@
 use crossbeam_channel::bounded;
 use itertools::Itertools;
-use polytype::{Context, Type, TypeSchema};
+use polytype::{Context, Type, TypeScheme};
 use rayon::join;
 use rayon::prelude::*;
 use std::borrow::Cow;
@@ -127,7 +127,7 @@ where
     P: Fn(
         &I,
         &Language,
-        &[(TypeSchema, Vec<(Expression, f64, f64)>)],
+        &[(TypeScheme, Vec<(Expression, f64, f64)>)],
         &CompressionParams,
         &mut Vec<Expression>,
     ) + Sync,
@@ -135,7 +135,7 @@ where
         &I,
         &T::Expression,
         &mut Language,
-        &[(TypeSchema, Vec<(Expression, f64, f64)>)],
+        &[(TypeScheme, Vec<(Expression, f64, f64)>)],
         &CompressionParams,
     ) -> Option<f64>
         + Sync,
@@ -145,7 +145,7 @@ where
         T::Expression,
         Expression,
         &Language,
-        &mut Vec<(TypeSchema, Vec<(Expression, f64, f64)>)>,
+        &mut Vec<(TypeScheme, Vec<(Expression, f64, f64)>)>,
         &CompressionParams,
     ),
 {
@@ -253,7 +253,7 @@ where
 }
 
 /// A convenient frontier representation.
-pub type RescoredFrontier = (TypeSchema, Vec<(Expression, f64, f64)>);
+pub type RescoredFrontier = (TypeScheme, Vec<(Expression, f64, f64)>);
 
 pub fn joint_mdl(dsl: &Language, frontiers: &[RescoredFrontier]) -> f64 {
     frontiers
@@ -408,7 +408,7 @@ impl Language {
 
     /// This is similar to `enumerator::likelihood` but it does a lot more work to determine
     /// _outside_ counts.
-    fn uses(&self, request: &TypeSchema, expr: &Expression) -> (f64, Uses) {
+    fn uses(&self, request: &TypeScheme, expr: &Expression) -> (f64, Uses) {
         let mut ctx = Context::default();
         let tp = request.clone().instantiate_owned(&mut ctx);
         let env = Rc::new(LinkedList::default());
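For orientation, `RescoredFrontier` above pairs a task's request type with scored candidate solutions. A sketch of its shape (reading the two `f64`s as log-prior and log-likelihood is an assumption here):

```rust
use polytype::{ptp, tp, TypeScheme};
use programinduction::lambda::{Expression, Language};

// local mirror of `compression::RescoredFrontier` for illustration
type Frontier = (TypeScheme, Vec<(Expression, f64, f64)>);

fn main() {
    let dsl = Language::uniform(vec![("+", ptp!(@arrow[tp!(int), tp!(int), tp!(int)]))]);
    let expr = dsl.parse("(λ (+ $0 $0))").expect("parses");
    // (expression, log-prior, log-likelihood), ordering assumed
    let frontier: Frontier = (ptp!(@arrow[tp!(int), tp!(int)]), vec![(expr, -2.0, 0.0)]);
    let _ = frontier;
}
```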
diff --git a/src/lambda/enumerator.rs b/src/lambda/enumerator.rs
index cc78fea..65e2e91 100644
--- a/src/lambda/enumerator.rs
+++ b/src/lambda/enumerator.rs
@@ -1,4 +1,4 @@
-use polytype::{Context, Type, TypeSchema};
+use polytype::{Context, Type, TypeScheme};
 use std::collections::VecDeque;
 use std::rc::Rc;
 
@@ -23,7 +23,7 @@ fn budget_interval(n: u32) -> (f64, f64) {
     }
 }
 
-pub fn run<F>(dsl: &Language, request: TypeSchema, termination_condition: F)
+pub fn run<F>(dsl: &Language, request: TypeScheme, termination_condition: F)
 where
     F: Fn(Expression, f64) -> bool + Sync,
 {
@@ -60,7 +60,7 @@ where
     }
 }
 
-pub fn likelihood(dsl: &Language, request: &TypeSchema, expr: &Expression) -> f64 {
+pub fn likelihood(dsl: &Language, request: &TypeScheme, expr: &Expression) -> f64 {
     let mut ctx = Context::default();
     let env = Rc::new(LinkedList::default());
     let t = request.clone().instantiate_owned(&mut ctx);
diff --git a/src/lambda/eval.rs b/src/lambda/eval.rs
index 68ed1bd..2e8f9a2 100644
--- a/src/lambda/eval.rs
+++ b/src/lambda/eval.rs
@@ -1,5 +1,5 @@
 //! Evaluation happens by calling primitives provided by an evaluator.
 
-use polytype::TypeSchema;
+use polytype::TypeScheme;
 use std::collections::VecDeque;
 use std::sync::Arc;
@@ -488,7 +488,7 @@ use self::ReducedExpression::*;
 #[derive(Clone, PartialEq)]
 pub enum ReducedExpression<V> {
     Value(V),
-    Primitive(String, TypeSchema),
+    Primitive(String, TypeScheme),
     Application(Vec<ReducedExpression<V>>),
     /// store depth (never zero) for nested abstractions.
     Abstraction(usize, Box<ReducedExpression<V>>),
@@ -856,11 +856,11 @@ where
     }
 }
 
-fn arity(mut tp: &TypeSchema) -> usize {
+fn arity(mut tp: &TypeScheme) -> usize {
     let mut tp = loop {
         match *tp {
-            TypeSchema::Monotype(ref t) => break t,
-            TypeSchema::Polytype { ref body, .. } => tp = body,
+            TypeScheme::Monotype(ref t) => break t,
+            TypeScheme::Polytype { ref body, .. } => tp = body,
         }
     };
     let mut count = 0;
@@ -871,11 +871,11 @@
     count
 }
 
-fn is_arrow(mut tp: &TypeSchema) -> bool {
+fn is_arrow(mut tp: &TypeScheme) -> bool {
     loop {
         match *tp {
-            TypeSchema::Monotype(ref t) => break t.as_arrow().is_some(),
-            TypeSchema::Polytype { ref body, .. } => tp = body,
+            TypeScheme::Monotype(ref t) => break t.as_arrow().is_some(),
+            TypeScheme::Polytype { ref body, .. } => tp = body,
         }
     }
 }
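eval.rs leaves primitive semantics to a user-supplied evaluator. A minimal sketch of one (this follows the crate's `Evaluator` trait shape; the primitives are illustrative):

```rust
use programinduction::lambda::Evaluator;

struct ArithEvaluator;
impl Evaluator for ArithEvaluator {
    // `Space` is the domain of values an expression can reduce to.
    type Space = i32;
    fn evaluate(&self, primitive: &str, inps: &[Self::Space]) -> Self::Space {
        match primitive {
            "+" => inps[0] + inps[1],
            "1" => 1,
            _ => unreachable!("unknown primitive"),
        }
    }
}
```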
diff --git a/src/lambda/mod.rs b/src/lambda/mod.rs
index 241a2a2..1bcc9e1 100644
--- a/src/lambda/mod.rs
+++ b/src/lambda/mod.rs
@@ -42,7 +42,7 @@ pub use self::eval::{
 pub use self::parser::ParseError;
 
 use crossbeam_channel::bounded;
-use polytype::{Context, Type, TypeSchema, UnificationError};
+use polytype::{Context, Type, TypeScheme, UnificationError};
 use rayon::spawn;
 use std::collections::{HashMap, VecDeque};
 use std::error::Error;
@@ -59,8 +59,8 @@ const FREE_VAR_COST: f64 = 0.01;
 /// polymorphically-typed lambda calculus with corresponding production log-probabilities.
 #[derive(Debug, Clone)]
 pub struct Language {
-    pub primitives: Vec<(String, TypeSchema, f64)>,
-    pub invented: Vec<(Expression, TypeSchema, f64)>,
+    pub primitives: Vec<(String, TypeScheme, f64)>,
+    pub invented: Vec<(Expression, TypeScheme, f64)>,
     pub variable_logprob: f64,
     /// Symmetry breaking prevents certain productions from being made. Specifically, an item
     /// `(f, i, a)` means that enumeration will not yield an application of `f` where the `i`th
@@ -74,7 +74,7 @@ impl Language {
     /// A uniform distribution over primitives and invented expressions, as well as the abstraction
     /// operation.
-    pub fn uniform(primitives: Vec<(&str, TypeSchema)>) -> Self {
+    pub fn uniform(primitives: Vec<(&str, TypeScheme)>) -> Self {
         let primitives = primitives
             .into_iter()
             .map(|(s, t)| (String::from(s), t, 0f64))
             .collect();
@@ -115,7 +115,7 @@ impl Language {
     /// ```
     ///
     /// [`Expression`]: enum.Expression.html
-    pub fn infer(&self, expr: &Expression) -> Result<TypeSchema, InferenceError> {
+    pub fn infer(&self, expr: &Expression) -> Result<TypeScheme, InferenceError> {
         let mut ctx = Context::default();
         let env = VecDeque::new();
         let mut indices = HashMap::new();
@@ -161,7 +161,7 @@ impl Language {
     /// ```
     ///
     /// [`add_symmetry_violation`]: #method.add_symmetry_violation
-    pub fn enumerate(&self, tp: TypeSchema) -> Box<dyn Iterator<Item = (Expression, f64)>> {
+    pub fn enumerate(&self, tp: TypeScheme) -> Box<dyn Iterator<Item = (Expression, f64)>> {
         let (tx, rx) = bounded(1);
         let dsl = self.clone();
         spawn(move || {
@@ -328,7 +328,7 @@ impl Language {
     /// let expr = dsl.parse("(λ (λ (+ (+ $0 1) $1)))").unwrap();
     /// assert_eq!(dsl.likelihood(&req, &expr), -8.317766166719343);
     /// ```
-    pub fn likelihood(&self, request: &TypeSchema, expr: &Expression) -> f64 {
+    pub fn likelihood(&self, request: &TypeScheme, expr: &Expression) -> f64 {
         enumerator::likelihood(self, request, expr)
     }
 
@@ -583,7 +583,7 @@ impl Language {
 impl EC for Language {
     type Expression = Expression;
     type Params = CompressionParams;
-    fn enumerate<F>(&self, tp: TypeSchema, termination_condition: F)
+    fn enumerate<F>(&self, tp: TypeScheme, termination_condition: F)
     where
         F: Fn(Expression, f64) -> bool + Sync,
     {
@@ -784,7 +784,7 @@ impl Expression {
         *self = new_self;
         true
     }
-    fn strip_invented(&self, invented: &[(Expression, TypeSchema, f64)]) -> Expression {
+    fn strip_invented(&self, invented: &[(Expression, TypeScheme, f64)]) -> Expression {
         match *self {
             Expression::Application(ref f, ref x) => Expression::Application(
                 Box::new(f.strip_invented(invented)),
@@ -926,7 +926,7 @@
 /// ```
 pub fn task_by_evaluation<E, V>(
     evaluator: E,
-    tp: TypeSchema,
+    tp: TypeScheme,
     examples: impl AsRef<[(Vec<V>, V)]> + Sync,
 ) -> impl Task<[(Vec<V>, V)], Representation = Language, Expression = Expression>
 where
@@ -946,7 +946,7 @@ where
 /// [`task_by_evaluation`]: fn.task_by_evaluation.html
 pub fn task_by_lazy_evaluation<E, V>(
     evaluator: E,
-    tp: TypeSchema,
+    tp: TypeScheme,
     examples: impl AsRef<[(Vec<V>, V)]> + Sync,
 ) -> impl Task<[(Vec<V>, V)], Representation = Language, Expression = Expression>
 where
@@ -962,7 +962,7 @@ where
 
 struct LambdaTask<E, O> {
     evaluator: Arc<E>,
-    tp: TypeSchema,
+    tp: TypeScheme,
     examples: O,
 }
 impl<
@@ -989,7 +989,7 @@ impl<
             f64::NEG_INFINITY
         }
     }
-    fn tp(&self) -> &TypeSchema {
+    fn tp(&self) -> &TypeScheme {
         &self.tp
     }
     fn observation(&self) -> &[(Vec<V>, V)] {
@@ -1020,7 +1020,7 @@ impl<
             f64::NEG_INFINITY
         }
    }
-    fn tp(&self) -> &TypeSchema {
+    fn tp(&self) -> &TypeScheme {
        &self.tp
    }
    fn observation(&self) -> &[(Vec<V>, V)] {
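A quick sketch tying together `uniform`, `enumerate`, and `infer` from the hunks above (the primitives are illustrative):

```rust
use polytype::{ptp, tp};
use programinduction::lambda::Language;

fn main() {
    let dsl = Language::uniform(vec![
        ("1", ptp!(int)),
        ("+", ptp!(@arrow[tp!(int), tp!(int), tp!(int)])),
    ]);
    // `enumerate` yields (Expression, log-prior) pairs as a boxed iterator.
    for (expr, logprior) in dsl.enumerate(ptp!(int)).take(5) {
        let scheme = dsl.infer(&expr).expect("well-typed");
        println!("{:?} : {} ({})", expr, scheme, logprior);
    }
}
```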
diff --git a/src/lib.rs b/src/lib.rs
index 09cc5c7..e37fc0d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -108,7 +108,7 @@ pub use crate::ec::*;
 pub use crate::gp::*;
 
 use std::marker::PhantomData;
-use polytype::TypeSchema;
+use polytype::TypeScheme;
 
 /// A task which is solved by an expression under some representation.
 ///
@@ -127,7 +127,7 @@ pub trait Task<Observation: ?Sized>: Sync {
     fn oracle(&self, dsl: &Self::Representation, expr: &Self::Expression) -> f64;
 
     /// An expression that is considered valid for the `oracle` is one of this type.
-    fn tp(&self) -> &TypeSchema;
+    fn tp(&self) -> &TypeScheme;
 
     /// Some program induction methods can take advantage of observations. This may often
     /// practically be the [`unit`] type `()`.
@@ -138,7 +138,7 @@ pub trait Task<Observation: ?Sized>: Sync {
 
 pub fn noop_task<R, E>(
     value: f64,
-    ptp: TypeSchema,
+    ptp: TypeScheme,
 ) -> impl Task<(), Representation = R, Expression = E> {
     NoopTask {
         value,
@@ -149,7 +149,7 @@ pub fn noop_task<R, E>(
 
 pub fn simple_task<R, E>(
     oracle_fn: impl Fn(&R, &E) -> f64 + Sync,
-    ptp: TypeSchema,
+    ptp: TypeScheme,
 ) -> impl Task<(), Representation = R, Expression = E> {
     SimpleTask {
         oracle_fn,
@@ -160,7 +160,7 @@ pub fn simple_task<R, E>(
 
 struct NoopTask<R, E> {
     value: f64,
-    ptp: TypeSchema,
+    ptp: TypeScheme,
     _marker: PhantomData<fn(R, E)>, // using fn to give Send/Sync
 }
 impl<R, E> Task<()> for NoopTask<R, E> {
@@ -169,7 +169,7 @@ impl<R, E> Task<()> for NoopTask<R, E> {
     fn oracle(&self, _dsl: &Self::Representation, _expr: &Self::Expression) -> f64 {
         self.value
     }
-    fn tp(&self) -> &TypeSchema {
+    fn tp(&self) -> &TypeScheme {
         &self.ptp
     }
     fn observation(&self) -> &() {
@@ -179,7 +179,7 @@ impl<R, E> Task<()> for NoopTask<R, E> {
 
 struct SimpleTask<R, E, F> {
     oracle_fn: F,
-    ptp: TypeSchema,
+    ptp: TypeScheme,
     _marker: PhantomData<fn(R, E)>, // using fn to give Send/Sync
 }
 impl<R, E, F> Task<()> for SimpleTask<R, E, F>
@@ -191,7 +191,7 @@ where
     fn oracle(&self, dsl: &Self::Representation, expr: &Self::Expression) -> f64 {
         (self.oracle_fn)(dsl, expr)
    }
-    fn tp(&self) -> &TypeSchema {
+    fn tp(&self) -> &TypeScheme {
        &self.ptp
    }
    fn observation(&self) -> &() {
diff --git a/src/pcfg/mod.rs b/src/pcfg/mod.rs
index a7960e6..4c42c1d 100644
--- a/src/pcfg/mod.rs
+++ b/src/pcfg/mod.rs
@@ -39,7 +39,7 @@ pub use self::parser::ParseError;
 
 use crossbeam_channel::bounded;
 use itertools::Itertools;
-use polytype::{Type, TypeSchema};
+use polytype::{Type, TypeScheme};
 use rand::distributions::{Distribution, Uniform};
 use rand::seq::SliceRandom;
 use rand::Rng;
@@ -317,12 +317,12 @@ impl EC for Grammar {
     type Expression = AppliedRule;
     type Params = EstimationParams;
 
-    fn enumerate<F>(&self, tp: TypeSchema, termination_condition: F)
+    fn enumerate<F>(&self, tp: TypeScheme, termination_condition: F)
     where
         F: FnMut(Self::Expression, f64) -> bool,
     {
         match tp {
-            TypeSchema::Monotype(tp) => enumerator::new(self, tp, termination_condition),
+            TypeScheme::Monotype(tp) => enumerator::new(self, tp, termination_condition),
             _ => panic!("PCFGs can't handle polytypes"),
         }
     }
@@ -405,10 +405,10 @@ impl GP<()> for Grammar {
         _params: &Self::Params,
         rng: &mut R,
         pop_size: usize,
-        tp: &TypeSchema,
+        tp: &TypeScheme,
     ) -> Vec<Self::Expression> {
         let tp = match *tp {
-            TypeSchema::Monotype(ref tp) => tp,
+            TypeScheme::Monotype(ref tp) => tp,
             _ => panic!("PCFGs can't handle polytypes"),
         };
         (0..pop_size).map(|_| self.sample(tp, rng)).collect()
@@ -575,13 +575,13 @@ where
     PcfgTask {
         evaluator,
         output,
-        tp: TypeSchema::Monotype(tp),
+        tp: TypeScheme::Monotype(tp),
     }
 }
 
 struct PcfgTask<'a, V, F> {
     evaluator: F,
     output: &'a V,
-    tp: TypeSchema,
+    tp: TypeScheme,
 }
 impl<V, F> Task<V> for PcfgTask<'_, V, F>
@@ -602,7 +602,7 @@ where
             f64::NEG_INFINITY
         }
     }
-    fn tp(&self) -> &TypeSchema {
+    fn tp(&self) -> &TypeScheme {
         &self.tp
     }
     fn observation(&self) -> &V {
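Both pcfg impls above accept only monotypes and panic otherwise, so pcfg requests are built by wrapping a bare `Type`; a small illustration:

```rust
use polytype::{tp, TypeScheme};

fn main() {
    // PCFG requests are monomorphic: wrap the Type instead of generalizing it.
    let request = TypeScheme::Monotype(tp!(list(tp!(int))));
    match request {
        TypeScheme::Monotype(ref t) => println!("ok: {}", t),
        _ => panic!("PCFGs can't handle polytypes"),
    }
}
```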
diff --git a/src/trs/lexicon.rs b/src/trs/lexicon.rs
index cd91237..921302b 100644
--- a/src/trs/lexicon.rs
+++ b/src/trs/lexicon.rs
@@ -1,5 +1,5 @@
 use itertools::{repeat_n, Itertools};
-use polytype::{Context as TypeContext, Type, TypeSchema, Variable as TypeVar};
+use polytype::{Context as TypeContext, Type, TypeScheme, Variable as TypeVar};
 use rand::Rng;
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
@@ -40,7 +40,7 @@ impl Lexicon {
     ///
     /// # Example
     ///
-    /// See [`polytype::ptp`] for details on constructing [`polytype::TypeSchema`]s.
+    /// See [`polytype::ptp`] for details on constructing [`polytype::TypeScheme`]s.
     ///
     /// ```
     /// use polytype::{ptp, tp, Context as TypeContext};
@@ -56,11 +56,11 @@ impl Lexicon {
     /// let lexicon = Lexicon::new(operators, deterministic, TypeContext::default());
     /// ```
     ///
-    /// [`polytype::ptp`]: https://docs.rs/polytype/~6.0/polytype/macro.ptp.html
-    /// [`polytype::TypeSchema`]: https://docs.rs/polytype/~6.0/polytype/enum.TypeSchema.html
+    /// [`polytype::ptp`]: https://docs.rs/polytype/~7.0/polytype/macro.ptp.html
+    /// [`polytype::TypeScheme`]: https://docs.rs/polytype/~7.0/polytype/enum.TypeScheme.html
     /// [`term_rewriting::Operator`]: https://docs.rs/term_rewriting/~0.3/term_rewriting/struct.Operator.html
     pub fn new(
-        operators: Vec<(u32, Option<String>, TypeSchema)>,
+        operators: Vec<(u32, Option<String>, TypeScheme)>,
         deterministic: bool,
         ctx: TypeContext,
     ) -> Lexicon {
@@ -88,7 +88,7 @@ impl Lexicon {
     ///
     /// # Example
     ///
-    /// See [`polytype::ptp`] for details on constructing [`polytype::TypeSchema`]s.
+    /// See [`polytype::ptp`] for details on constructing [`polytype::TypeScheme`]s.
     ///
     /// ```
     /// use polytype::{ptp, tp, Context as TypeContext};
@@ -121,8 +121,8 @@ impl Lexicon {
     /// let lexicon = Lexicon::from_signature(sig, ops, vars, background, templates, deterministic, TypeContext::default());
     /// ```
     ///
-    /// [`polytype::ptp`]: https://docs.rs/polytype/~6.0/polytype/macro.ptp.html
-    /// [`polytype::TypeSchema`]: https://docs.rs/polytype/~6.0/polytype/enum.TypeSchema.html
+    /// [`polytype::ptp`]: https://docs.rs/polytype/~7.0/polytype/macro.ptp.html
+    /// [`polytype::TypeScheme`]: https://docs.rs/polytype/~7.0/polytype/enum.TypeScheme.html
     /// [`term_rewriting::Signature`]: https://docs.rs/term_rewriting/~0.3/term_rewriting/struct.Signature.html
     /// [`term_rewriting::Operator`]: https://docs.rs/term_rewriting/~0.3/term_rewriting/struct.Operator.html
     /// [`term_rewriting::Rule`]: https://docs.rs/term_rewriting/~0.3/term_rewriting/struct.Rule.html
@@ -130,8 +130,8 @@ impl Lexicon {
     /// [`term_rewriting::RuleContext`]: https://docs.rs/term_rewriting/~0.3/term_rewriting/struct.RuleContext.html
     pub fn from_signature(
         signature: Signature,
-        ops: Vec<TypeSchema>,
-        vars: Vec<TypeSchema>,
+        ops: Vec<TypeScheme>,
+        vars: Vec<TypeScheme>,
         background: Vec<Rule>,
         templates: Vec<RuleContext>,
         deterministic: bool,
@@ -161,7 +161,7 @@ impl Lexicon {
     pub fn context(&self) -> TypeContext {
         self.0.read().expect("poisoned lexicon").ctx.clone()
     }
-    /// Infer the [`polytype::TypeSchema`] associated with a [`term_rewriting::Context`].
+    /// Infer the [`polytype::TypeScheme`] associated with a [`term_rewriting::Context`].
     ///
     /// # Example
     ///
@@ -195,46 +195,46 @@ impl Lexicon {
     /// };
     /// let mut ctx = lexicon.context();
     ///
-    /// let inferred_schema = lexicon.infer_context(&context, &mut ctx).unwrap();
+    /// let inferred_scheme = lexicon.infer_context(&context, &mut ctx).unwrap();
     ///
-    /// assert_eq!(inferred_schema, ptp![int]);
+    /// assert_eq!(inferred_scheme, ptp![int]);
     /// ```
     ///
-    /// [`polytype::TypeSchema`]: https://docs.rs/polytype/~6.0/polytype/enum.TypeSchema.html
+    /// [`polytype::TypeScheme`]: https://docs.rs/polytype/~7.0/polytype/enum.TypeScheme.html
     /// [`term_rewriting::Context`]: https://docs.rs/term_rewriting/~0.3/term_rewriting/enum.Context.html
     pub fn infer_context(
         &self,
         context: &Context,
         ctx: &mut TypeContext,
-    ) -> Result<TypeSchema, TypeError> {
+    ) -> Result<TypeScheme, TypeError> {
         let lex = self.0.write().expect("poisoned lexicon");
         lex.infer_context(context, ctx)
     }
-    /// Infer the `TypeSchema` associated with a `RuleContext`.
+    /// Infer the `TypeScheme` associated with a `RuleContext`.
     pub fn infer_rulecontext(
         &self,
         context: &RuleContext,
         ctx: &mut TypeContext,
-    ) -> Result<TypeSchema, TypeError> {
+    ) -> Result<TypeScheme, TypeError> {
         let lex = self.0.write().expect("poisoned lexicon");
         lex.infer_rulecontext(context, ctx)
     }
-    /// Infer the `TypeSchema` associated with a `Rule`.
-    pub fn infer_rule(&self, rule: &Rule, ctx: &mut TypeContext) -> Result<TypeSchema, TypeError> {
+    /// Infer the `TypeScheme` associated with a `Rule`.
+    pub fn infer_rule(&self, rule: &Rule, ctx: &mut TypeContext) -> Result<TypeScheme, TypeError> {
         let lex = self.0.write().expect("poisoned lexicon");
         lex.infer_rule(rule, ctx).map(|(r, _, _)| r)
     }
-    /// Infer the `TypeSchema` associated with a collection of `Rules`.
+    /// Infer the `TypeScheme` associated with a collection of `Rules`.
     pub fn infer_rules(
         &self,
         rules: &[Rule],
         ctx: &mut TypeContext,
-    ) -> Result<TypeSchema, TypeError> {
+    ) -> Result<TypeScheme, TypeError> {
         let lex = self.0.read().expect("poisoned lexicon");
         lex.infer_rules(rules, ctx)
     }
-    /// Infer the `TypeSchema` associated with a `Rule`.
-    pub fn infer_op(&self, op: &Operator) -> Result<TypeSchema, TypeError> {
+    /// Infer the `TypeScheme` associated with an `Operator`.
+    pub fn infer_op(&self, op: &Operator) -> Result<TypeScheme, TypeError> {
         self.0.write().expect("poisoned lexicon").op_tp(op)
     }
     /// Sample a [`term_rewriting::Term`].
@@ -254,7 +254,7 @@ impl Lexicon {
     /// let deterministic = false;
     /// let mut lexicon = Lexicon::new(operators, deterministic, TypeContext::default());
     ///
-    /// let schema = ptp![int];
+    /// let scheme = ptp![int];
     /// let mut ctx = lexicon.context();
     /// let invent = true;
     /// let variable = true;
@@ -262,7 +262,7 @@ impl Lexicon {
     /// let max_size = 50;
     ///
     /// let rng = &mut SmallRng::from_seed([1u8; 32]);
-    /// let term = lexicon.sample_term(rng, &schema, &mut ctx, atom_weights, invent, variable, max_size).unwrap();
+    /// let term = lexicon.sample_term(rng, &scheme, &mut ctx, atom_weights, invent, variable, max_size).unwrap();
     /// ```
     ///
     /// [`term_rewriting::Term`]: https://docs.rs/term_rewriting/~0.3/term_rewriting/enum.Term.html
@@ -270,7 +270,7 @@ impl Lexicon {
     #[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
     pub fn sample_term<R: Rng>(
         &mut self,
         rng: &mut R,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
@@ -280,7 +280,7 @@ impl Lexicon {
         let mut lex = self.0.write().expect("poisoned lexicon");
         lex.sample_term(
             rng,
-            schema,
+            scheme,
             ctx,
             atom_weights,
             invent,
@@ -289,7 +289,7 @@ impl Lexicon {
             0,
         )
     }
-    /// Sample a `Term` conditioned on a `Context` rather than a `TypeSchema`.
+    /// Sample a `Term` conditioned on a `Context` rather than a `TypeScheme`.
     #[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
     pub fn sample_term_from_context(
         &mut self,
@@ -317,16 +317,16 @@ impl Lexicon {
     pub fn sample_rule<R: Rng>(
         &mut self,
         rng: &mut R,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
         max_size: usize,
     ) -> Result<Rule, SampleError> {
         let mut lex = self.0.write().expect("poisoned lexicon");
-        lex.sample_rule(rng, schema, ctx, atom_weights, invent, max_size, 0)
+        lex.sample_rule(rng, scheme, ctx, atom_weights, invent, max_size, 0)
     }
-    /// Sample a `Rule` conditioned on a `Context` rather than a `TypeSchema`.
+    /// Sample a `Rule` conditioned on a `Context` rather than a `TypeScheme`.
     pub fn sample_rule_from_context(
         &mut self,
         rng: &mut R,
@@ -343,38 +343,38 @@ impl Lexicon {
     pub fn logprior_term(
         &self,
         term: &Term,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
     ) -> Result<f64, SampleError> {
         let lex = self.0.read().expect("posioned lexicon");
-        lex.logprior_term(term, schema, ctx, atom_weights, invent)
+        lex.logprior_term(term, scheme, ctx, atom_weights, invent)
     }
     /// Give the log probability of sampling a Rule.
     pub fn logprior_rule(
         &self,
         rule: &Rule,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
     ) -> Result<f64, SampleError> {
         let lex = self.0.read().expect("poisoned lexicon");
-        lex.logprior_rule(rule, schema, ctx, atom_weights, invent)
+        lex.logprior_rule(rule, scheme, ctx, atom_weights, invent)
     }
     /// Give the log probability of sampling a TRS.
     pub fn logprior_utrs(
         &self,
         utrs: &UntypedTRS,
-        schemas: &[TypeSchema],
+        schemes: &[TypeScheme],
         p_rule: f64,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
     ) -> Result<f64, SampleError> {
         let lex = self.0.read().expect("poisoned lexicon");
-        lex.logprior_utrs(utrs, schemas, p_rule, ctx, atom_weights, invent)
+        lex.logprior_utrs(utrs, schemes, p_rule, ctx, atom_weights, invent)
     }
 
     /// merge two `TRS` into a single `TRS`.
@@ -417,8 +417,8 @@ impl fmt::Display for Lexicon {
 #[derive(Debug, Clone)]
 pub(crate) struct Lex {
-    pub(crate) ops: Vec<TypeSchema>,
-    pub(crate) vars: Vec<TypeSchema>,
+    pub(crate) ops: Vec<TypeScheme>,
+    pub(crate) vars: Vec<TypeScheme>,
     pub(crate) signature: Signature,
     pub(crate) background: Vec<Rule>,
     /// Rule templates to use when sampling rules.
@@ -430,11 +430,11 @@ pub(crate) struct Lex {
 impl fmt::Display for Lex {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         writeln!(f, "Signature:")?;
-        for (op, schema) in self.signature.operators().iter().zip(&self.ops) {
-            writeln!(f, "{}: {}", op.display(), schema)?;
+        for (op, scheme) in self.signature.operators().iter().zip(&self.ops) {
+            writeln!(f, "{}: {}", op.display(), scheme)?;
         }
-        for (var, schema) in self.signature.variables().iter().zip(&self.vars) {
-            writeln!(f, "{}: {}", var.display(), schema)?;
+        for (var, scheme) in self.signature.variables().iter().zip(&self.vars) {
+            writeln!(f, "{}: {}", var.display(), scheme)?;
         }
         writeln!(f, "\nBackground: {}", self.background.len())?;
         for rule in &self.background {
@@ -449,8 +449,8 @@ impl fmt::Display for Lex {
 }
 impl Lex {
     fn free_vars(&self) -> Vec<TypeVar> {
-        let vars_fvs = self.vars.iter().flat_map(TypeSchema::free_vars);
-        let ops_fvs = self.ops.iter().flat_map(TypeSchema::free_vars);
+        let vars_fvs = self.vars.iter().flat_map(TypeScheme::free_vars);
+        let ops_fvs = self.ops.iter().flat_map(TypeScheme::free_vars);
         vars_fvs.chain(ops_fvs).unique().collect()
     }
     fn free_vars_applied(&self, ctx: &TypeContext) -> Vec<TypeVar> {
@@ -462,7 +462,7 @@ impl Lex {
     }
     fn invent_variable(&mut self, tp: &Type) -> Variable {
         let var = self.signature.new_var(None);
-        self.vars.push(TypeSchema::Monotype(tp.clone()));
+        self.vars.push(TypeScheme::Monotype(tp.clone()));
         var
     }
     fn fit_atom(
@@ -503,11 +503,11 @@ impl Lex {
         let can_be_variable = true;
         for arg_tp in arg_types {
             let subtype = arg_tp.apply(ctx);
-            let arg_schema = TypeSchema::Monotype(arg_tp);
+            let arg_scheme = TypeScheme::Monotype(arg_tp);
             let result = self
                 .sample_term_internal(
                     rng,
-                    &arg_schema,
+                    &arg_scheme,
                     ctx,
                     atom_weights,
                     invent,
@@ -545,14 +545,14 @@ impl Lex {
         tp.apply_mut(ctx);
         Ok(tp)
     }
-    fn var_tp(&self, v: &Variable) -> Result<TypeSchema, TypeError> {
+    fn var_tp(&self, v: &Variable) -> Result<TypeScheme, TypeError> {
         if let Some(idx) = self.signature.variables().iter().position(|x| x == v) {
             Ok(self.vars[idx].clone())
         } else {
             Err(TypeError::VarNotFound)
         }
     }
-    fn op_tp(&self, o: &Operator) -> Result<TypeSchema, TypeError> {
+    fn op_tp(&self, o: &Operator) -> Result<TypeScheme, TypeError> {
         if let Some(idx) = self.signature.operators().iter().position(|x| x == o) {
             Ok(self.ops[idx].clone())
         } else {
@@ -560,13 +560,13 @@ impl Lex {
         }
     }
 
-    fn infer_atom(&self, atom: &Atom) -> Result<TypeSchema, TypeError> {
+    fn infer_atom(&self, atom: &Atom) -> Result<TypeScheme, TypeError> {
         match *atom {
             Atom::Operator(ref o) => self.op_tp(o),
             Atom::Variable(ref v) => self.var_tp(v),
         }
     }
-    pub fn infer_term(&self, term: &Term, ctx: &mut TypeContext) -> Result<TypeSchema, TypeError> {
+    pub fn infer_term(&self, term: &Term, ctx: &mut TypeContext) -> Result<TypeScheme, TypeError> {
         let tp = self.infer_term_internal(term, ctx)?;
         let lex_vars = self.free_vars_applied(ctx);
         Ok(tp.apply(ctx).generalize(&lex_vars))
@@ -593,7 +593,7 @@ impl Lex {
         &self,
         context: &Context,
         ctx: &mut TypeContext,
-    ) -> Result<TypeSchema, TypeError> {
+    ) -> Result<TypeScheme, TypeError> {
         let tp = self.infer_context_internal(context, ctx, vec![], &mut HashMap::new())?;
         let lex_vars = self.free_vars_applied(ctx);
         Ok(tp.apply(ctx).generalize(&lex_vars))
@@ -631,28 +631,28 @@ impl Lex {
         &self,
         r: &Rule,
         ctx: &mut TypeContext,
-    ) -> Result<(TypeSchema, TypeSchema, Vec<TypeSchema>), TypeError> {
-        let lhs_schema = self.infer_term(&r.lhs, ctx)?;
-        let lhs_type = lhs_schema.instantiate(ctx);
+    ) -> Result<(TypeScheme, TypeScheme, Vec<TypeScheme>), TypeError> {
+        let lhs_scheme = self.infer_term(&r.lhs, ctx)?;
+        let lhs_type = lhs_scheme.instantiate(ctx);
         let mut rhs_types = Vec::with_capacity(r.rhs.len());
-        let mut rhs_schemas = Vec::with_capacity(r.rhs.len());
+        let mut rhs_schemes = Vec::with_capacity(r.rhs.len());
         for rhs in &r.rhs {
-            let rhs_schema = self.infer_term(rhs, ctx)?;
-            rhs_types.push(rhs_schema.instantiate(ctx));
-            rhs_schemas.push(rhs_schema);
+            let rhs_scheme = self.infer_term(rhs, ctx)?;
+            rhs_types.push(rhs_scheme.instantiate(ctx));
+            rhs_schemes.push(rhs_scheme);
         }
         for rhs_type in rhs_types {
             ctx.unify(&lhs_type, &rhs_type)?;
         }
         let lex_vars = self.free_vars_applied(ctx);
-        let rule_schema = lhs_type.apply(ctx).generalize(&lex_vars);
-        Ok((rule_schema, lhs_schema, rhs_schemas))
+        let rule_scheme = lhs_type.apply(ctx).generalize(&lex_vars);
+        Ok((rule_scheme, lhs_scheme, rhs_schemes))
     }
     pub fn infer_rules(
         &self,
         rules: &[Rule],
         ctx: &mut TypeContext,
-    ) -> Result<TypeSchema, TypeError> {
+    ) -> Result<TypeScheme, TypeError> {
         let tp = ctx.new_variable();
         let mut rule_tps = vec![];
         for rule in rules.iter() {
@@ -669,7 +669,7 @@ impl Lex {
         &self,
         context: &RuleContext,
         ctx: &mut TypeContext,
-    ) -> Result<TypeSchema, TypeError> {
+    ) -> Result<TypeScheme, TypeError> {
         let tp = self.infer_rulecontext_internal(context, ctx, &mut HashMap::new())?;
         let lex_vars = self.free_vars_applied(ctx);
         Ok(tp.apply(ctx).generalize(&lex_vars))
@@ -705,7 +705,7 @@ impl Lex {
     pub fn sample_term<R: Rng>(
         &mut self,
         rng: &mut R,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
@@ -715,7 +715,7 @@ impl Lex {
     ) -> Result<Term, SampleError> {
         self.sample_term_internal(
             rng,
-            schema,
+            scheme,
             ctx,
             atom_weights,
             invent,
@@ -729,7 +729,7 @@ impl Lex {
     pub fn sample_term_internal<R: Rng>(
         &mut self,
         rng: &mut R,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
@@ -741,7 +741,7 @@ impl Lex {
         if size >= max_size {
             return Err(SampleError::SizeExceeded(size, max_size));
         }
-        let tp = schema.instantiate(ctx);
+        let tp = scheme.instantiate(ctx);
         let (atom, arg_types) =
             self.prepare_option(rng, vars, atom_weights, invent, variable, &tp, ctx)?;
         self.place_atom(
@@ -818,10 +818,10 @@ impl Lex {
         let lex_vars = self.free_vars_applied(ctx);
         let mut context_vars = context.variables();
         for p in &hole_places {
-            let schema = &map[p].apply(ctx).generalize(&lex_vars);
+            let scheme = &map[p].apply(ctx).generalize(&lex_vars);
             let subterm = self.sample_term_internal(
                 rng,
-                schema,
+                scheme,
                 ctx,
                 atom_weights,
                 invent,
@@ -838,7 +838,7 @@ impl Lex {
     pub fn sample_rule<R: Rng>(
         &mut self,
         rng: &mut R,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
@@ -851,7 +851,7 @@ impl Lex {
         let mut vars = vec![];
         let lhs = self.sample_term_internal(
             rng,
-            schema,
+            scheme,
             ctx,
             atom_weights,
             invent,
@@ -862,7 +862,7 @@ impl Lex {
         )?;
         let rhs = self.sample_term_internal(
             rng,
-            schema,
+            scheme,
             ctx,
             atom_weights,
             false,
@@ -895,12 +895,12 @@ impl Lex {
         let mut context_vars = context.variables();
         self.infer_rulecontext_internal(&context, ctx, &mut map)?;
         for p in &hole_places {
-            let schema = TypeSchema::Monotype(map[p].apply(ctx));
+            let scheme = TypeScheme::Monotype(map[p].apply(ctx));
             let can_invent = p[0] == 0 && invent;
             let can_be_variable = p == &vec![0];
             let subterm = self.sample_term_internal(
                 rng,
-                &schema,
+                &scheme,
                 ctx,
                 atom_weights,
                 can_invent,
@@ -919,13 +919,13 @@ impl Lex {
     fn logprior_term(
         &self,
         term: &Term,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         (vw, cw, ow): (f64, f64, f64),
         invent: bool,
     ) -> Result<f64, SampleError> {
-        // instantiate the typeschema
-        let tp = schema.instantiate(ctx);
+        // instantiate the typescheme
+        let tp = scheme.instantiate(ctx);
         // setup the existing options
         let (mut vs, mut cs, mut os) = (vec![], vec![], vec![]);
         let atoms = self.signature.atoms();
@@ -970,8 +970,8 @@ impl Lex {
         let mut lp = olp;
         for (subterm, mut arg_tp) in term.args().iter().zip(arg_types) {
             arg_tp.apply_mut(ctx);
-            let arg_schema = TypeSchema::Monotype(arg_tp.clone());
-            lp += self.logprior_term(subterm, &arg_schema, ctx, (vw, cw, ow), invent)?;
+            let arg_scheme = TypeScheme::Monotype(arg_tp.clone());
+            lp += self.logprior_term(subterm, &arg_scheme, ctx, (vw, cw, ow), invent)?;
             let final_type = self.infer_term(subterm, ctx)?.instantiate_owned(ctx);
             if ctx.unify(&arg_tp, &final_type).is_err() {
                 return Ok(f64::NEG_INFINITY);
@@ -985,15 +985,15 @@ impl Lex {
     fn logprior_rule(
         &self,
         rule: &Rule,
-        schema: &TypeSchema,
+        scheme: &TypeScheme,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
         invent: bool,
     ) -> Result<f64, SampleError> {
         let mut lp = 0.0;
-        let lp_lhs = self.logprior_term(&rule.lhs, schema, ctx, atom_weights, invent)?;
+        let lp_lhs = self.logprior_term(&rule.lhs, scheme, ctx, atom_weights, invent)?;
         for rhs in &rule.rhs {
-            let tmp_lp = self.logprior_term(rhs, schema, ctx, atom_weights, false)?;
+            let tmp_lp = self.logprior_term(rhs, scheme, ctx, atom_weights, false)?;
             lp += tmp_lp + lp_lhs;
         }
         Ok(lp)
@@ -1001,7 +1001,7 @@ impl Lex {
     fn logprior_utrs(
         &self,
         utrs: &UntypedTRS,
-        schemas: &[TypeSchema],
+        schemes: &[TypeScheme],
         p_rule: f64,
         ctx: &mut TypeContext,
         atom_weights: (f64, f64, f64),
@@ -1018,8 +1018,8 @@ impl Lex {
                 continue;
             }
             let mut rule_ps = vec![];
-            for schema in schemas {
-                let tmp_lp = self.logprior_rule(rule, schema, ctx, atom_weights, invent)?;
+            for scheme in schemes {
+                let tmp_lp = self.logprior_rule(rule, scheme, ctx, atom_weights, invent)?;
                 rule_ps.push(tmp_lp);
             }
             p_rules += logsumexp(&rule_ps);
@@ -1036,7 +1036,7 @@ impl GP<[Rule]> for Lexicon {
         params: &Self::Params,
         rng: &mut R,
         pop_size: usize,
-        _tp: &TypeSchema,
+        _tp: &TypeScheme,
     ) -> Vec<Self::Expression> {
         let trs = TRS::new(
             self,
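Under the new naming, constructing a `Lexicon` follows the doc examples above (a sketch; the integer operators are illustrative):

```rust
use polytype::{ptp, tp, Context as TypeContext};
use programinduction::trs::Lexicon;

fn main() {
    // (arity, name, scheme) triples, with schemes as `TypeScheme`s built by `ptp!`
    let operators = vec![
        (2, Some("PLUS".to_string()), ptp![@arrow[tp![int], tp![int], tp![int]]]),
        (1, Some("SUCC".to_string()), ptp![@arrow[tp![int], tp![int]]]),
        (0, Some("ZERO".to_string()), ptp![int]),
    ];
    let deterministic = false;
    let lexicon = Lexicon::new(operators, deterministic, TypeContext::default());
    println!("{}", lexicon);
}
```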
a/src/trs/mod.rs b/src/trs/mod.rs index 3a1b1d2..4c4f7ac 100644 --- a/src/trs/mod.rs +++ b/src/trs/mod.rs @@ -52,7 +52,7 @@ pub use self::rewrite::TRS; use crate::Task; use polytype; -use polytype::TypeSchema; +use polytype::TypeScheme; use serde::{Deserialize, Serialize}; use std::fmt; use term_rewriting::{Rule, TRSError}; @@ -184,7 +184,7 @@ pub fn task_by_rewrite<'a, O: Sync + 'a>( struct TrsTask<'a, O> { data: &'a [Rule], params: ModelParams, - tp: TypeSchema, + tp: TypeScheme, observation: O, } impl<'a, O: Sync> Task for TrsTask<'a, O> { @@ -194,7 +194,7 @@ impl<'a, O: Sync> Task for TrsTask<'a, O> { fn oracle(&self, _: &Lexicon, h: &TRS) -> f64 { -h.posterior(self.data, self.params) } - fn tp(&self) -> &TypeSchema { + fn tp(&self) -> &TypeScheme { &self.tp } fn observation(&self) -> &O { diff --git a/src/trs/parser.rs b/src/trs/parser.rs index 71b8614..13d8c0d 100644 --- a/src/trs/parser.rs +++ b/src/trs/parser.rs @@ -1,6 +1,6 @@ use super::lexicon::Lexicon; use super::rewrite::TRS; -use polytype::{Context as TypeContext, TypeSchema}; +use polytype::{Context as TypeContext, TypeScheme}; use std::fmt; use std::io; use term_rewriting::{ @@ -52,13 +52,13 @@ impl ::std::error::Error for ParseError { /// # Lexicon syntax /// /// `input` is parsed as a `lexicon`, defined below in [augmented Backus-Naur -/// form]. The definition of `schema` is as given in [`polytype`], while other +/// form]. The definition of `scheme` is as given in [`polytype`], while other /// terms are as given in [`term_rewriting`]: /// /// ```text /// lexicon = *wsp *( *comment declaration ";" *comment ) *wsp /// -/// declaration = *wsp identifier *wsp ":" *wsp schema *wsp +/// declaration = *wsp identifier *wsp ":" *wsp scheme *wsp /// ``` /// /// # Background syntax @@ -169,23 +169,23 @@ enum AtomName { fn make_atom( name: AtomName, sig: &mut Signature, - schema: TypeSchema, - vars: &mut Vec, - ops: &mut Vec, + scheme: TypeScheme, + vars: &mut Vec, + ops: &mut Vec, ) -> Atom { match name { AtomName::Variable(s) => { let v = sig.new_var(Some(s)); - vars.push(schema); + vars.push(scheme); Atom::Variable(v) } AtomName::Operator(s) => { - let arity = schema + let arity = scheme .instantiate(&mut TypeContext::default()) .args() .map_or(0, |args| args.len()); let o = sig.new_op(arity as u32, Some(s)); - ops.push(schema); + ops.push(scheme); Atom::Operator(o) } } @@ -216,37 +216,34 @@ fn parse_irrelevant(input: &mut &str) -> PResult<()> { repeat(0.., (parse_comment, multispace0)).parse_next(input)?; Ok(()) } -fn parse_declaration(input: &mut &str) -> PResult<(AtomName, TypeSchema)> { - let (atom_name, schema_txt) = delimited( +fn parse_declaration(input: &mut &str) -> PResult<(AtomName, TypeScheme)> { + delimited( multispace0, separated_pair( parse_atom_name, (multispace0, ':', multispace0), - terminated(take_till(0.., ';'), ';'), + terminated(take_till(0.., ';'), ';').parse_to(), ), multispace0, ) - .parse_next(input)?; - let schema = - TypeSchema::parse(schema_txt).map_err(|_| ErrMode::Backtrack(Default::default()))?; - Ok((atom_name, schema)) + .parse_next(input) } fn parse_simple_lexicon( input: &mut &str, deterministic: bool, ctx: TypeContext, ) -> PResult { - let atom_infos: Vec<(AtomName, TypeSchema)> = repeat( + let atom_infos: Vec<(AtomName, TypeScheme)> = repeat( 0.., delimited(parse_irrelevant, parse_declaration, parse_irrelevant), ) .parse_next(input)?; let mut sig = Signature::default(); - let mut vars: Vec = vec![]; - let mut ops: Vec = vec![]; - for (atom_name, schema) in atom_infos { - make_atom(atom_name, 
&mut sig, schema, &mut vars, &mut ops); + let mut vars: Vec = vec![]; + let mut ops: Vec = vec![]; + for (atom_name, scheme) in atom_infos { + make_atom(atom_name, &mut sig, scheme, &mut vars, &mut ops); } Ok(Lexicon::from_signature( sig, @@ -342,11 +339,11 @@ fn _parse_templates( fn add_parsed_variables_to_lexicon(lex: &Lexicon, ctx: &mut TypeContext) { let n_vars = lex.0.read().unwrap().signature.variables().len(); - let n_schemas = lex.0.read().unwrap().vars.len(); - let diff = n_vars - n_schemas; + let n_schemes = lex.0.read().unwrap().vars.len(); + let diff = n_vars - n_schemes; for _ in 0..diff { - let schema = TypeSchema::Monotype(ctx.new_variable()); - lex.0.write().unwrap().vars.push(schema); + let scheme = TypeScheme::Monotype(ctx.new_variable()); + lex.0.write().unwrap().vars.push(scheme); } }