diff --git a/utils/vectors.go b/utils/vectors.go index 75030d4..265773a 100644 --- a/utils/vectors.go +++ b/utils/vectors.go @@ -3,6 +3,7 @@ package util import ( "fmt" "math" + "golang.org/x/exp/slices" // like in tokenClassification.go ) @@ -78,8 +79,11 @@ func Norm(v []float32, p int) float64 { // Normalize single vector according to: https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html func Normalize(embedding []float32, p int) []float32 { - const eps = 1e-12 - normalizeDenominator := float32(max(Norm(embedding, p), eps)) + var normalizeDenominator float32 = 1e-12 + embeddingNorm := float32(Norm(embedding, p)) + if embeddingNorm > normalizeDenominator { + normalizeDenominator = embeddingNorm + } for i, v := range embedding { embedding[i] = v / normalizeDenominator }