Skip to content

Commit

Permalink
Fix unicode dot (#376)
Browse files Browse the repository at this point in the history
* The `.` regex should not take the ASCII fast path

see #375 for an example of undefined behavior because of this fast path.

TLDR: the ASCII fast path will stop matching on the first matching byte,
however this would split multi-byte codepoints. Combined with
`Lexer::remaining` (or even just capturing the string like in the issue),
this leads to non-utf8 strings escaping into user code. This is UNSOUND.

* Add tests for unicode dot in both str and bytes

* chore(lib): rewrite using ClassUnicode methods

As suggested by @RustyYato

* Revert "chore(lib): rewrite using ClassUnicode methods"

This reverts commit 80bd23f.

* try: fallback to previous impl

Tests are still passing

* try: add repetition check

* chore(lib): cleanup code

* fix and move

* another fix

---------

Co-authored-by: Jérome Eertmans <[email protected]>
  • Loading branch information
RustyYato and jeertmans authored Feb 16, 2024
1 parent ba69cc3 commit 81f923c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 15 deletions.
40 changes: 25 additions & 15 deletions logos-codegen/src/graph/regex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::mir::{Class, ClassUnicode, Literal, Mir};

impl<Leaf: Disambiguate + Debug> Graph<Leaf> {
pub fn regex(&mut self, mir: Mir, then: NodeId) -> NodeId {
self.parse_mir(mir, then, None, None)
self.parse_mir(mir, then, None, None, false)
}

fn parse_mir(
Expand All @@ -16,6 +16,7 @@ impl<Leaf: Disambiguate + Debug> Graph<Leaf> {
then: NodeId,
miss: Option<NodeId>,
reserved: Option<ReservedId>,
repeated: bool,
) -> NodeId {
match mir {
Mir::Empty => then,
Expand All @@ -29,21 +30,21 @@ impl<Leaf: Disambiguate + Debug> Graph<Leaf> {
None => self.reserve(),
};

self.parse_mir(*mir, this.get(), Some(miss), Some(this))
self.parse_mir(*mir, this.get(), Some(miss), Some(this), true)
}
Mir::Maybe(mir) => {
let miss = match miss {
Some(id) => self.merge(id, then),
None => then,
};

self.parse_mir(*mir, then, Some(miss), reserved)
self.parse_mir(*mir, then, Some(miss), reserved, true)
}
Mir::Alternation(alternation) => {
let mut fork = Fork::new().miss(miss);

for mir in alternation {
let id = self.parse_mir(mir, then, None, None);
let id = self.parse_mir(mir, then, None, None, repeated);
let alt = self.fork_off(id);

fork.merge(alt, self);
Expand Down Expand Up @@ -73,7 +74,7 @@ impl<Leaf: Disambiguate + Debug> Graph<Leaf> {
}
None
}
Mir::Class(Class::Unicode(class)) if is_one_ascii(&class) => {
Mir::Class(Class::Unicode(class)) if is_one_ascii(&class, repeated) => {
cur -= 1;
ropebuf[cur] = class.ranges()[0].into();
None
Expand All @@ -97,7 +98,7 @@ impl<Leaf: Disambiguate + Debug> Graph<Leaf> {

for mir in concat.drain(1..).rev() {
if let Some(mir) = handle_bytes(self, mir, &mut then) {
then = self.parse_mir(mir, then, None, None);
then = self.parse_mir(mir, then, None, None, false);
}
}

Expand All @@ -107,10 +108,10 @@ impl<Leaf: Disambiguate + Debug> Graph<Leaf> {

self.insert_or_push(reserved, rope)
}
Some(mir) => self.parse_mir(mir, then, miss, reserved),
Some(mir) => self.parse_mir(mir, then, miss, reserved, false),
}
}
Mir::Class(Class::Unicode(class)) if !is_ascii(&class) => {
Mir::Class(Class::Unicode(class)) if !is_ascii(&class, repeated) => {
let mut ropes = class
.iter()
.flat_map(|range| Utf8Sequences::new(range.start(), range.end()))
Expand Down Expand Up @@ -160,25 +161,34 @@ impl<Leaf: Disambiguate + Debug> Graph<Leaf> {
}
}

fn is_ascii(class: &ClassUnicode) -> bool {
class.iter().all(|range| {
/// Return wether current class unicode is ascii.
///
/// Because unicode ranges are iterated in increasing order,
/// it is only necessary to check the last range.
///
/// If the check is performed in a repetition,
/// a fast path is used by checking if end of range is 0x0010_FFFF.
fn is_ascii(class: &ClassUnicode, repeated: bool) -> bool {
class.iter().last().map_or(true, |range| {
let start = range.start() as u32;
let end = range.end() as u32;

start < 128 && (end < 128 || end == 0x0010_FFFF)
end < 128 || (repeated && start < 128 && end == 0x0010_FFFF)
})
}

fn is_one_ascii(class: &ClassUnicode) -> bool {
/// Return wether current class unicode is ascii and only contains
/// one range.
///
/// See [`is_ascii`] function for more details.
fn is_one_ascii(class: &ClassUnicode, repeated: bool) -> bool {
if class.ranges().len() != 1 {
return false;
}

let range = &class.ranges()[0];
let start = range.start() as u32;
let end = range.end() as u32;

start < 128 && (end < 128 || end == 0x0010_FFFF)
end < 128 || (repeated && start < 128 && end == 0x0010_FFFF)
}

#[cfg(test)]
Expand Down
56 changes: 56 additions & 0 deletions tests/tests/unicode_dot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use logos::Logos as _;
use logos_derive::Logos;

#[derive(Logos, Debug, PartialEq)]
enum TestUnicodeDot {
#[regex(".")]
Dot,
}

#[test]
fn test_unicode_dot_str_ascii() {
let mut lexer = TestUnicodeDot::lexer("a");
assert_eq!(lexer.next(), Some(Ok(TestUnicodeDot::Dot)));
assert_eq!(lexer.remainder(), "");
assert_eq!(lexer.next(), None);
}

#[test]
fn test_unicode_dot_str_unicode() {
let mut lexer = TestUnicodeDot::lexer("");
assert_eq!(lexer.next(), Some(Ok(TestUnicodeDot::Dot)));
assert_eq!(lexer.remainder(), "");
assert_eq!(lexer.next(), None);
}

#[derive(Logos, Debug, PartialEq)]
enum TestUnicodeDotBytes {
#[regex(".", priority = 100)]
Dot,
#[regex(b".", priority = 0)]
InvalidUtf8,
}

#[test]
fn test_unicode_dot_bytes_ascii() {
let mut lexer = TestUnicodeDotBytes::lexer(b"a");
assert_eq!(lexer.next(), Some(Ok(TestUnicodeDotBytes::Dot)));
assert_eq!(lexer.remainder(), b"");
assert_eq!(lexer.next(), None);
}

#[test]
fn test_unicode_dot_bytes_unicode() {
let mut lexer = TestUnicodeDotBytes::lexer("".as_bytes());
assert_eq!(lexer.next(), Some(Ok(TestUnicodeDotBytes::Dot)));
assert_eq!(lexer.remainder(), b"");
assert_eq!(lexer.next(), None);
}

#[test]
fn test_unicode_dot_bytes_invalid_utf8() {
let mut lexer = TestUnicodeDotBytes::lexer(b"\xff");
assert_eq!(lexer.next(), Some(Ok(TestUnicodeDotBytes::InvalidUtf8)));
assert_eq!(lexer.remainder(), b"");
assert_eq!(lexer.next(), None);
}

0 comments on commit 81f923c

Please sign in to comment.