diff --git a/bpe.go b/bpe.go index 7ab3cf3..73c6add 100644 --- a/bpe.go +++ b/bpe.go @@ -84,7 +84,7 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, retIDs := []uint{} retTokens := []string{} - textRunes := []rune(text) + textLength := len(text) start := 0 @@ -94,11 +94,11 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, startFind := start for { - temp := cutRunes(textRunes, startFind, len(textRunes)) + temp := cutText(text, startFind, textLength) nextSpecial = findRegex2StringIndex(temp, specialRegex) if nextSpecial != nil { - token := cutRunes(textRunes, startFind+nextSpecial[0], startFind+nextSpecial[1]) + token := cutText(text, startFind+nextSpecial[0], startFind+nextSpecial[1]) if _, ok := allowedSpecial[token]; ok { break } @@ -109,13 +109,13 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, } } - end := len([]rune(text)) + end := textLength if nextSpecial != nil { end = start + nextSpecial[0] } - for _, mat := range findRegex2AllStringMatchIndex(cutRunes(textRunes, start, end), regex) { - piece := cutRunes(textRunes, start+mat[0], start+mat[1]) + for _, mat := range findRegex2AllStringMatchIndex(cutText(text, start, end), regex) { + piece := cutText(text, start+mat[0], start+mat[1]) if id, ok := bpe.encoder[piece]; ok { retIDs = append(retIDs, id) retTokens = append(retTokens, piece) @@ -129,7 +129,7 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, } if nextSpecial != nil { - temp := cutRunes(textRunes, start+nextSpecial[0], start+nextSpecial[1]) + temp := cutText(text, start+nextSpecial[0], start+nextSpecial[1]) id := bpe.specialTokensEncoder[temp] retIDs = append(retIDs, id) retTokens = append(retTokens, temp) @@ -149,10 +149,8 @@ func (bpe *coreBPE) EncodeOrdinary(text string) ([]uint, []string) { retIDs := []uint{} retTokens := []string{} - textRunes := []rune(text) - for _, mat := range findRegex2AllStringMatchIndex(text, bpe.tlRegex) { - piece := cutRunes(textRunes, mat[0], mat[1]) + piece := cutText(text, mat[0], mat[1]) if id, ok := bpe.encoder[piece]; ok { retIDs = append(retIDs, id) retTokens = append(retTokens, piece) @@ -302,19 +300,14 @@ func findRegex2AllStringMatchIndex(text string, reg *regexp2.Regexp) [][]int { return matches } -// cutRunes extracts a substring from the given rune slice based on the start and end indices. -// It returns the extracted substring as a string. -// The function takes the input rune slice, start index, and end index as parameters. -// If the start index is negative, it is set to 0. -// If the end index exceeds the length of the rune slice, it is set to the length of the rune slice. -func cutRunes(runes []rune, start, end int) string { +func cutText(text string, start, end int) string { if start < 0 { start = 0 } - if end > len(runes) { - end = len(runes) + if end > len(text) { + end = len(text) } - return string(runes[start:end]) + return text[start:end] }