diff --git a/lua/colorizer.lua b/lua/colorizer.lua index f127b19..382636d 100644 --- a/lua/colorizer.lua +++ b/lua/colorizer.lua @@ -11,20 +11,34 @@ local nvim_buf_get_lines = vim.api.nvim_buf_get_lines local nvim_get_current_buf = vim.api.nvim_get_current_buf local band, lshift, bor, tohex = bit.band, bit.lshift, bit.bor, bit.tohex local rshift = bit.rshift -local floor = math.floor +local floor, min, max = math.floor, math.min, math.max local COLOR_MAP local COLOR_TRIE +local COLOR_NAME_MINLEN, COLOR_NAME_MAXLEN +local COLOR_NAME_SETTINGS = { + lowercase = true; + strip_digits = true; +} --- Setup the COLOR_MAP and COLOR_TRIE local function initialize_trie() if not COLOR_TRIE then - COLOR_MAP = nvim.get_color_map() + COLOR_MAP = {} COLOR_TRIE = Trie() - - for k, v in pairs(COLOR_MAP) do - COLOR_MAP[k] = tohex(v, 6) - COLOR_TRIE:insert(k) + for k, v in pairs(nvim.get_color_map()) do + if not (COLOR_NAME_SETTINGS.strip_digits and k:match("%d+$")) then + COLOR_NAME_MINLEN = COLOR_NAME_MINLEN and min(#k, COLOR_NAME_MINLEN) or #k + COLOR_NAME_MAXLEN = COLOR_NAME_MAXLEN and max(#k, COLOR_NAME_MAXLEN) or #k + local rgb_hex = tohex(v, 6) + COLOR_MAP[k] = rgb_hex + COLOR_TRIE:insert(k) + if COLOR_NAME_SETTINGS.lowercase then + local lowercase = k:lower() + COLOR_MAP[lowercase] = rgb_hex + COLOR_TRIE:insert(lowercase) + end + end end end end @@ -41,18 +55,28 @@ local function merge(...) end local DEFAULT_OPTIONS = { - RGB = true; -- #RGB hex codes - RRGGBB = true; -- #RRGGBB hex codes - names = true; -- "Name" codes like Blue - RRGGBBAA = false; -- #RRGGBBAA hex codes - rgb_fn = false; -- CSS rgb() and rgba() functions - hsl_fn = false; -- CSS hsl() and hsla() functions - css = false; -- Enable all CSS features: rgb_fn, hsl_fn, names, RGB, RRGGBB - css_fn = false; -- Enable all CSS *functions*: rgb_fn, hsl_fn + RGB = true; -- #RGB hex codes + RRGGBB = true; -- #RRGGBB hex codes + names = true; -- "Name" codes like Blue + RRGGBBAA = false; -- #RRGGBBAA hex codes + rgb_fn = false; -- CSS rgb() and rgba() functions + hsl_fn = false; -- CSS hsl() and hsla() functions + css = false; -- Enable all CSS features: rgb_fn, hsl_fn, names, RGB, RRGGBB + css_fn = false; -- Enable all CSS *functions*: rgb_fn, hsl_fn + lowercase = false; -- Enable lowercase "Name" codes -- Available modes: foreground, background mode = 'background'; -- Set the display mode. } +-- -- TODO use rgb as the return value from the matcher functions +-- -- instead of the rgb_hex. Can be the highlight key as well +-- -- when you shift it left 8 bits. Use the lower 8 bits for +-- -- indicating which highlight mode to use. +-- ffi.cdef [[ +-- typedef struct { uint8_t r, g, b; } colorizer_rgb; +-- ]] +-- local rgb_t = ffi.typeof 'colorizer_rgb' + -- Create a lookup table where the bottom 4 bits are used to indicate the -- category and the top 4 bits are the hex value of the ASCII byte. local BYTE_CATEGORY = ffi.new 'uint8_t[256]' @@ -86,12 +110,12 @@ end local function byte_is_hex(byte) return band(BYTE_CATEGORY[byte], CATEGORY_HEX) ~= 0 end - +local function byte_is_alpha(byte) + return band(BYTE_CATEGORY[byte], CATEGORY_ALPHA) ~= 0 +end local function byte_is_alphanumeric(byte) - local category = BYTE_CATEGORY[byte] - return band(category, CATEGORY_ALPHANUM) ~= 0 + return band(BYTE_CATEGORY[byte], CATEGORY_ALPHANUM) ~= 0 end - local function parse_hex(b) return rshift(BYTE_CATEGORY[b], 4) end @@ -144,15 +168,21 @@ local function hsl_to_rgb(h, s, l) return 255*hue_to_rgb(p, q, h + 1/3), 255*hue_to_rgb(p, q, h), 255*hue_to_rgb(p, q, h - 1/3) end -local function name_parser(line, i) +local function color_name_parser(line, i, allow_lowercase) + -- Disallow prefixing with an alphanumeric character if i > 1 and byte_is_alphanumeric(line:byte(i-1)) then return end - local prefix = COLOR_TRIE:longest_prefix(line:sub(i)) + if #line < i + COLOR_NAME_MINLEN - 1 then return end + if not allow_lowercase then + local b = line:byte(i) + -- This means it's lowercase. + if byte_is_alpha(b) and b >= 0x61 then return end + end + local prefix = COLOR_TRIE:longest_prefix(line, i) if prefix then - -- Check if there is a letter here so as to disallow matching here. + -- Disallow trailing alphanumeric characters. -- Take the Blue out of Blueberry - -- Line end or non-letter. local next_byte_index = i + #prefix if #line >= next_byte_index and byte_is_alphanumeric(line:byte(next_byte_index)) then return @@ -174,7 +204,7 @@ local function rgb_hex_parser(line, i, minlen, maxlen) local n = j + maxlen local alpha local v = 0 - while j <= math.min(n, #line) do + while j <= min(n, #line) do local b = line:byte(j) if not byte_is_hex(b) then break end if j - i >= 7 then @@ -205,23 +235,22 @@ end -- Things like pumblend might be useful here. local css_fn = {} do - local css_rgb_fn_minimum_length = #'rgb(0,0,0)' - 1 - local css_rgba_fn_minimum_length = #'rgba(0,0,0,0)' - 1 - local css_hsl_fn_minimum_length = #'hsl(0,0%,0%)' - 1 - local css_hsla_fn_minimum_length = #'hsla(0,0%,0%,0)' - 1 + local CSS_RGB_FN_MINIMUM_LENGTH = #'rgb(0,0,0)' - 1 + local CSS_RGBA_FN_MINIMUM_LENGTH = #'rgba(0,0,0,0)' - 1 + local CSS_HSL_FN_MINIMUM_LENGTH = #'hsl(0,0%,0%)' - 1 + local CSS_HSLA_FN_MINIMUM_LENGTH = #'hsla(0,0%,0%,0)' - 1 function css_fn.rgb(line, i) - if #line < i + css_rgb_fn_minimum_length then return end + if #line < i + CSS_RGB_FN_MINIMUM_LENGTH then return end local r, g, b, match_end = line:sub(i):match("^rgb%(%s*(%d+%%?)%s*,%s*(%d+%%?)%s*,%s*(%d+%%?)%s*%)()") if not match_end then return end r = percent_or_hex(r) if not r then return end g = percent_or_hex(g) if not g then return end b = percent_or_hex(b) if not b then return end - local rgb_hex = ("%02x%02x%02x"):format(r,g,b) - if #rgb_hex ~= 6 then return end + local rgb_hex = tohex(bor(lshift(r, 16), lshift(g, 8), b), 6) return match_end - 1, rgb_hex end function css_fn.hsl(line, i) - if #line < i + css_hsl_fn_minimum_length then return end + if #line < i + CSS_HSL_FN_MINIMUM_LENGTH then return end local h, s, l, match_end = line:sub(i):match("^hsl%(%s*(%d+)%s*,%s*(%d+)%%%s*,%s*(%d+)%%%s*%)()") if not match_end then return end h = tonumber(h) if h > 360 then return end @@ -229,24 +258,22 @@ do l = tonumber(l) if l > 100 then return end local r, g, b = hsl_to_rgb(h/360, s/100, l/100) if r == nil or g == nil or b == nil then return end - local rgb_hex = ("%02x%02x%02x"):format(floor(r), floor(g), floor(b)) - if #rgb_hex ~= 6 then return end + local rgb_hex = tohex(bor(lshift(floor(r), 16), lshift(floor(g), 8), floor(b)), 6) return match_end - 1, rgb_hex end function css_fn.rgba(line, i) - if #line < i + css_rgba_fn_minimum_length then return end + if #line < i + CSS_RGBA_FN_MINIMUM_LENGTH then return end local r, g, b, a, match_end = line:sub(i):match("^rgba%(%s*(%d+%%?)%s*,%s*(%d+%%?)%s*,%s*(%d+%%?)%s*,%s*([.%d]+)%s*%)()") if not match_end then return end a = tonumber(a) if not a or a > 1 then return end r = percent_or_hex(r) if not r then return end g = percent_or_hex(g) if not g then return end b = percent_or_hex(b) if not b then return end - local rgb_hex = ("%02x%02x%02x"):format(floor(r*a), floor(g*a), floor(b*a)) - if #rgb_hex ~= 6 then return end + local rgb_hex = tohex(bor(lshift(floor(r*a), 16), lshift(floor(g*a), 8), floor(b*a)), 6) return match_end - 1, rgb_hex end function css_fn.hsla(line, i) - if #line < i + css_hsla_fn_minimum_length then return end + if #line < i + CSS_HSLA_FN_MINIMUM_LENGTH then return end local h, s, l, a, match_end = line:sub(i):match("^hsla%(%s*(%d+)%s*,%s*(%d+)%%%s*,%s*(%d+)%%%s*,%s*([.%d]+)%s*%)()") if not match_end then return end a = tonumber(a) if not a or a > 1 then return end @@ -255,8 +282,7 @@ do l = tonumber(l) if l > 100 then return end local r, g, b = hsl_to_rgb(h/360, s/100, l/100) if r == nil or g == nil or b == nil then return end - local rgb_hex = ("%02x%02x%02x"):format(floor(r*a), floor(g*a), floor(b*a)) - if #rgb_hex ~= 6 then return end + local rgb_hex = tohex(bor(lshift(floor(r*a), 16), lshift(floor(g*a), 8), floor(b*a)), 6) return match_end - 1, rgb_hex end end @@ -353,20 +379,22 @@ end local MATCHER_CACHE = {} local function make_matcher(options) - local enable_names = options.css or options.names - local enable_RGB = options.css or options.RGB - local enable_RRGGBB = options.css or options.RRGGBB - local enable_RRGGBBAA = options.css or options.RRGGBBAA - local enable_rgb = options.css or options.css_fns or options.rgb_fn - local enable_hsl = options.css or options.css_fns or options.hsl_fn + local enable_names = options.css or options.names + local enable_RGB = options.css or options.RGB + local enable_RRGGBB = options.css or options.RRGGBB + local enable_RRGGBBAA = options.css or options.RRGGBBAA + local enable_rgb = options.css or options.css_fns or options.rgb_fn + local enable_hsl = options.css or options.css_fns or options.hsl_fn + local enable_lowercase = options.css or options.lowercase local matcher_key = bor( - lshift(enable_names and 1 or 0, 0), - lshift(enable_RGB and 1 or 0, 1), - lshift(enable_RRGGBB and 1 or 0, 2), - lshift(enable_RRGGBBAA and 1 or 0, 3), - lshift(enable_rgb and 1 or 0, 4), - lshift(enable_hsl and 1 or 0, 5)) + lshift(enable_names and 1 or 0, 0), + lshift(enable_RGB and 1 or 0, 1), + lshift(enable_RRGGBB and 1 or 0, 2), + lshift(enable_RRGGBBAA and 1 or 0, 3), + lshift(enable_rgb and 1 or 0, 4), + lshift(enable_hsl and 1 or 0, 5), + lshift(enable_lowercase and 1 or 0, 6)) if matcher_key == 0 then return end @@ -377,15 +405,17 @@ local function make_matcher(options) local loop_matchers = {} if enable_names then - table.insert(loop_matchers, name_parser) + table.insert(loop_matchers, function(line, i) + return color_name_parser(line, i, enable_lowercase) + end) end do local valid_lengths = {[3] = enable_RGB, [6] = enable_RRGGBB, [8] = enable_RRGGBBAA} local minlen, maxlen for k, v in pairs(valid_lengths) do if v then - minlen = math.min(k, minlen or 99) - maxlen = math.max(k, maxlen or 0) + minlen = minlen and min(k, minlen) or k + maxlen = maxlen and max(k, maxlen) or k end end if minlen then @@ -535,19 +565,25 @@ end -- @param[opt={'*'}] filetypes A table/array of filetypes to selectively enable and/or customize. By default, enables all filetypes. -- @tparam[opt] {[string]=string} default_options Default options to apply for the filetypes enable. -- @usage require'colorizer'.setup() -local function setup(filetypes, default_options) +local function setup(filetypes, user_default_options, global_configuration) if not nvim.o.termguicolors then nvim.err_writeln("&termguicolors must be set") return end - initialize_trie() FILETYPE_OPTIONS = {} SETUP_SETTINGS = { exclusions = {}; - default_options = merge(DEFAULT_OPTIONS, default_options or {}); + default_options = merge(DEFAULT_OPTIONS, user_default_options or {}); + } + global_configuration = global_configuration or { + lowercase = true; + on_enter = false; } - -- This is just in case I accidentally reference the wrong thing here. - default_options = SETUP_SETTINGS.default_options + if type(global_configuration.lowercase) == 'boolean' then + COLOR_NAME_SETTINGS.lowercase = global_configuration.lowercase + end + -- Initialize this AFTER setting COLOR_NAME_SETTINGS + initialize_trie() function COLORIZER_SETUP_HOOK() local filetype = nvim.bo.filetype if SETUP_SETTINGS.exclusions[filetype] then @@ -558,7 +594,9 @@ local function setup(filetypes, default_options) end nvim.ex.augroup("ColorizerSetup") nvim.ex.autocmd_() - -- nvim.ex.autocmd("VimEnter * lua COLORIZER_SETUP_HOOK()") + if global_configuration.on_enter then + nvim.ex.autocmd("VimEnter * lua COLORIZER_SETUP_HOOK()") + end if not filetypes then nvim.ex.autocmd("FileType * lua COLORIZER_SETUP_HOOK()") else diff --git a/lua/trie.lua b/lua/trie.lua index 9d84487..f2369d6 100644 --- a/lua/trie.lua +++ b/lua/trie.lua @@ -14,11 +14,6 @@ -- You should have received a copy of the GNU General Public License -- along with this program. If not, see . local ffi = require 'ffi' -local bit = require 'bit' - -local bnot = bit.bnot -local band, bor, bxor = bit.band, bit.bor, bit.bxor -local lshift, rshift, rol = bit.lshift, bit.rshift, bit.rol ffi.cdef [[ struct Trie { @@ -33,58 +28,27 @@ local Trie_t = ffi.typeof('struct Trie') local Trie_ptr_t = ffi.typeof('$ *', Trie_t) local Trie_size = ffi.sizeof(Trie_t) -local function byte_to_index(b) - -- 0-9 starts at string.byte('0') == 0x30 == 48 == 0b0011_0000 - -- A-Z starts at string.byte('A') == 0x41 == 65 == 0b0100_0001 - -- a-z starts at string.byte('a') == 0x61 == 97 == 0b0110_0001 - - -- This works for mapping characters to - -- 0-9 A-Z a-z in that order - -- Letters have bit 0x40 set, so we use that as an indicator for - -- an additional offset from the space of the digits, and then - -- add the 10 allocated for the range of digits. - -- Then, within that indicator for letters, we subtract another - -- (65 - 97) which is the difference between lower and upper case - -- and add back another 26 to allocate for the range of uppercase - -- letters. - -- return b - 0x30 - -- + rshift(b, 6) * ( - -- 0x30 - 0x41 - -- + 10 - -- + band(1, rshift(b, 5)) * ( - -- 0x61 - 0x41 - -- + 26 - -- )) - return b - 0x30 - rshift(b, 6) * (7 + band(1, rshift(b, 5)) * 6) -end - -local function insensitive_byte_to_index(b) - -- return b - 0x30 - -- + rshift(b, 6) * ( - -- 0x30 - 0x61 - -- + 10 - -- ) - b = bor(b, 0x20) - return b - 0x30 - rshift(b, 6) * 39 -end - -local function verify_byte_to_index() - local chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - for i = 1, #chars do - local c = chars:sub(i,i) - local index = byte_to_index(string.byte(c)) - assert((i-1) == index, vim.inspect{index=index,c=c}) - end -end - local function trie_create() local ptr = ffi.C.malloc(Trie_size) ffi.fill(ptr, Trie_size) return ffi.cast(Trie_ptr_t, ptr) end +local function trie_destroy(trie) + if trie == nil then + return + end + for i = 0, 61 do + local child = trie.character[i] + if child ~= nil then + trie_destroy(child) + end + end + ffi.C.free(trie) +end + local INDEX_LOOKUP_TABLE = ffi.new 'uint8_t[256]' -local CHAR_LOOKUP_TABLE = ffi.new('char[62]', '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') +local CHAR_LOOKUP_TABLE = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' do local b = string.byte for i = 0, 255 do @@ -117,10 +81,10 @@ local function trie_insert(trie, value) return node, trie end -local function trie_search(trie, value) +local function trie_search(trie, value, start) if trie == nil then return false end local node = trie - for i = 1, #value do + for i = (start or 1), #value do local index = INDEX_LOOKUP_TABLE[value:byte(i)] if index == 255 then return @@ -134,12 +98,15 @@ local function trie_search(trie, value) return node.is_leaf end -local function trie_longest_prefix(trie, value) +local function trie_longest_prefix(trie, value, start) if trie == nil then return false end + -- insensitive = insensitive and 0x20 or 0 + start = start or 1 local node = trie local last_i = nil - for i = 1, #value do + for i = start, #value do local index = INDEX_LOOKUP_TABLE[value:byte(i)] +-- local index = INDEX_LOOKUP_TABLE[bor(insensitive, value:byte(i))] if index == 255 then break end @@ -153,7 +120,12 @@ local function trie_longest_prefix(trie, value) node = child end if last_i then - return value:sub(1, last_i) + -- Avoid a copy if the whole string is a match. + if start == 1 and last_i == #value then + return value + else + return value:sub(start, last_i) + end end end @@ -168,7 +140,7 @@ end local function index_to_char(index) if index < 0 or index > 61 then return end - return CHAR_LOOKUP_TABLE[index] + return CHAR_LOOKUP_TABLE:sub(index+1, index+1) end local function trie_as_table(trie) @@ -190,7 +162,7 @@ local function trie_as_table(trie) } end -local function print_trie_table(s) +local function print_trie_table(s, thicc) local mark if not s then return {'nil'} @@ -209,44 +181,48 @@ local function print_trie_table(s) end local lines = {} for _, child in ipairs(s.children) do - local child_lines = print_trie_table(child) + local child_lines = print_trie_table(child, thicc) for _, child_line in ipairs(child_lines) do table.insert(lines, child_line) end end - for i, v in ipairs(lines) do - if v:match("^[%w%d]") then + local child_count = 0 + for i, line in ipairs(lines) do + local line_parts = {} + -- if line[1] and line[1]:match("^%w") then + if line:match("^%w") then + child_count = child_count + 1 if i == 1 then - lines[i] = mark.."─"..v - elseif i == #lines then - lines[i] = "└──"..v + line_parts = {mark} + elseif i == #lines or child_count == #s.children then + line_parts = {"└─"} else - lines[i] = "├──"..v + line_parts = {"├─"} end + if thicc then table.insert(line_parts, "─") end else if i == 1 then - lines[i] = mark.."─"..v - elseif #s.children > 1 then - lines[i] = "│ "..v + line_parts = {mark} + if thicc then table.insert(line_parts, "─") end + elseif #s.children > 1 and child_count ~= #s.children then + line_parts = {thicc and "│ " or "│ "} else - lines[i] = " "..v + line_parts = {thicc and " " or " "} end end + table.insert(line_parts, line) + -- lines[i] = vim.tbl_flatten(line_parts) + lines[i] = table.concat(line_parts) end return lines end -local function trie_destroy(trie) +local function trie_to_string(trie, thicc) if trie == nil then - return + return 'nil' end - for i = 0, 61 do - local child = trie.character[i] - if child ~= nil then - trie_destroy(child) - end - end - ffi.C.free(trie) + local as_table = trie_as_table(trie) + return table.concat(print_trie_table(as_table, thicc), '\n') end local Trie_mt = { @@ -262,13 +238,9 @@ local Trie_mt = { search = trie_search; longest_prefix = trie_longest_prefix; extend = trie_extend; + to_string = trie_to_string; }; - __tostring = function(trie) - if trie == nil then - return 'nil' - end - return table.concat(print_trie_table(trie_as_table(trie)), '\n') - end; + __tostring = trie_to_string; __gc = trie_destroy; } diff --git a/test/expectation.txt b/test/expectation.txt index 915d4de..7813fe8 100644 --- a/test/expectation.txt +++ b/test/expectation.txt @@ -8,9 +8,9 @@ require'colorizer'.attach_to_buffer(0, {css=true}) #F0F #FF00FF #FFF00F8F - #F0F 1 - #FF00FF 1 - #FFF00F8F 1 + #F0F #F00 + #FF00FF #F00 + #FFF00F8F #F00 Blue Gray LightBlue Gray100 White White #def @@ -22,6 +22,7 @@ hsl(300,50%,50%) hsla(300,50%,50%,0.5) hsla(300,50%,50%,1.0000000000000001) hsla(360,50%,50%,1.0000000000000001) +blue gray lightblue gray100 white gold blue ]] --[[ FAIL diff --git a/test/print-trie.lua b/test/print-trie.lua new file mode 100644 index 0000000..09059e8 --- /dev/null +++ b/test/print-trie.lua @@ -0,0 +1,36 @@ +-- TODO this is kinda shitty +local function dirname(str,sep) + sep=sep or'/' + return str:match("(.*"..sep..")") +end + +local script_dir = dirname(arg[0]) +package.path = script_dir.."/../lua/?.lua;"..package.path + +local Trie = require 'trie' +local nvim = require 'nvim' + +local tohex = bit.tohex +local min, max = math.min, math.max + +local COLOR_NAME_SETTINGS = { + lowercase = false; + strip_digits = true; +} +COLOR_MAP = {} +COLOR_TRIE = Trie() +for k, v in pairs(nvim.get_color_map()) do + if not (COLOR_NAME_SETTINGS.strip_digits and k:match("%d+$")) then + COLOR_NAME_MINLEN = COLOR_NAME_MINLEN and min(#k, COLOR_NAME_MINLEN) or #k + COLOR_NAME_MAXLEN = COLOR_NAME_MAXLEN and max(#k, COLOR_NAME_MAXLEN) or #k + COLOR_MAP[k] = tohex(v, 6) + COLOR_TRIE:insert(k) + if COLOR_NAME_SETTINGS.lowercase then + local lowercase = k:lower() + COLOR_MAP[lowercase] = tohex(v, 6) + COLOR_TRIE:insert(lowercase) + end + end +end + +print(COLOR_TRIE)