Skip to content

Commit

Permalink
Performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 28, 2023
1 parent eb401d7 commit a4f79c6
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions bpe.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
retIDs := []uint{}
retTokens := []string{}

textRunes := []rune(text)
textLength := len(text)

start := 0

Expand All @@ -94,11 +94,11 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
startFind := start

for {
temp := cutRunes(textRunes, startFind, len(textRunes))
temp := cutText(text, startFind, textLength)
nextSpecial = findRegex2StringIndex(temp, specialRegex)

if nextSpecial != nil {
token := cutRunes(textRunes, startFind+nextSpecial[0], startFind+nextSpecial[1])
token := cutText(text, startFind+nextSpecial[0], startFind+nextSpecial[1])
if _, ok := allowedSpecial[token]; ok {
break
}
Expand All @@ -109,13 +109,13 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
}
}

end := len([]rune(text))
end := textLength
if nextSpecial != nil {
end = start + nextSpecial[0]
}

for _, mat := range findRegex2AllStringMatchIndex(cutRunes(textRunes, start, end), regex) {
piece := cutRunes(textRunes, start+mat[0], start+mat[1])
for _, mat := range findRegex2AllStringMatchIndex(cutText(text, start, end), regex) {
piece := cutText(text, start+mat[0], start+mat[1])
if id, ok := bpe.encoder[piece]; ok {
retIDs = append(retIDs, id)
retTokens = append(retTokens, piece)
Expand All @@ -129,7 +129,7 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
}

if nextSpecial != nil {
temp := cutRunes(textRunes, start+nextSpecial[0], start+nextSpecial[1])
temp := cutText(text, start+nextSpecial[0], start+nextSpecial[1])
id := bpe.specialTokensEncoder[temp]
retIDs = append(retIDs, id)
retTokens = append(retTokens, temp)
Expand All @@ -149,10 +149,8 @@ func (bpe *coreBPE) EncodeOrdinary(text string) ([]uint, []string) {
retIDs := []uint{}
retTokens := []string{}

textRunes := []rune(text)

for _, mat := range findRegex2AllStringMatchIndex(text, bpe.tlRegex) {
piece := cutRunes(textRunes, mat[0], mat[1])
piece := cutText(text, mat[0], mat[1])
if id, ok := bpe.encoder[piece]; ok {
retIDs = append(retIDs, id)
retTokens = append(retTokens, piece)
Expand Down Expand Up @@ -302,19 +300,14 @@ func findRegex2AllStringMatchIndex(text string, reg *regexp2.Regexp) [][]int {
return matches
}

// cutRunes extracts a substring from the given rune slice based on the start and end indices.
// It returns the extracted substring as a string.
// The function takes the input rune slice, start index, and end index as parameters.
// If the start index is negative, it is set to 0.
// If the end index exceeds the length of the rune slice, it is set to the length of the rune slice.
func cutRunes(runes []rune, start, end int) string {
func cutText(text string, start, end int) string {
if start < 0 {
start = 0
}

if end > len(runes) {
end = len(runes)
if end > len(text) {
end = len(text)
}

return string(runes[start:end])
return text[start:end]
}

0 comments on commit a4f79c6

Please sign in to comment.