Skip to content

Commit

Permalink
Added support for multi-label model results
Browse files Browse the repository at this point in the history
  • Loading branch information
gferraro committed Feb 14, 2023
1 parent 3d509e5 commit dde0061
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 31 deletions.
1 change: 0 additions & 1 deletion Melt/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


def species_identify(file_name, metadata_name, models, bird_model):

labels = identify_species(file_name, metadata_name, models)
other_labels = classify(file_name, bird_model)
other_labels = [other for other in other_labels if other["species"] != "human"]
Expand Down
69 changes: 39 additions & 30 deletions Melt/identify_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,50 @@
)


def load_recording(file, resample=48000):
    """Load an audio recording, resampling it when a different rate is requested.

    Args:
        file: Location of the recording; it is converted with ``str`` before
            being handed to ``librosa.load``.
        resample: Target sample rate. ``None`` keeps whatever rate the file
            was loaded at.

    Returns:
        A ``(frames, sr)`` tuple: the samples returned by librosa and the
        sample rate those samples are at.
    """
    # sr=None asks librosa to keep the file's own sample rate on load.
    samples, native_rate = librosa.load(str(file), sr=None)

    # Nothing to do when no target is given or the file already matches it.
    if resample is None or resample == native_rate:
        return samples, native_rate

    resampled = librosa.resample(
        samples, orig_sr=native_rate, target_sr=resample
    )
    return resampled, resample


def load_samples(path, segment_length, stride, hop_length=640):
frames, sr = librosa.load(path, sr=None)
frames, sr = load_recording(path)
mels = []
i = 0
n_fft = sr // 10
# hop_length = 640 # feature frame rate of 75

mel_all = librosa.feature.melspectrogram(
y=frames,
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
fmin=50,
fmax=11000,
n_mels=80,
)
mel_all = librosa.power_to_db(mel_all, ref=np.max)
mel_sample_size = int(1 + segment_length * sr / hop_length)
jumps_per_stride = int(mel_sample_size / segment_length)

length = mel_all.shape[1]
sample_size = int(sr * segment_length)
jumps_per_stride = int(sr * stride)
length = len(frames) / sr
end = 0
mel_samples = []
i = 0
while end < length:
start = int(jumps_per_stride * (i * stride))
end = start + mel_sample_size
mel = mel_all[:, start:end].copy()
data = frames[i * jumps_per_stride : i * jumps_per_stride + sample_size]
if len(data) != sample_size:
sample = np.zeros((sample_size))
sample[: len(data)] = data
data = sample
end += stride
# /start = int(jumps_per_stride * (i * stride))
mel = librosa.feature.melspectrogram(
y=data,
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
fmin=50,
fmax=11000,
n_mels=80,
)
mel = librosa.power_to_db(mel, ref=np.max)

# end = start + sample_size
mel_m = tf.reduce_mean(mel, axis=1)
mel_m = tf.expand_dims(mel_m, axis=1)
mel = mel - mel_m
if mel.shape[1] != 226:
# pad with zeros
empty = np.zeros(((80, 226)))
empty[:, : mel.shape[1]] = mel
mel = empty

mel_samples.append(mel)
i += 1
Expand All @@ -59,8 +67,11 @@ def load_samples(path, segment_length, stride, hop_length=640):
def load_model(model_path):
logging.debug("Loading %s", model_path)
model_path = Path(model_path)
model = tf.keras.models.load_model(model_path)
model.load_weights(model_path / "val_accuracy").expect_partial()
model = tf.keras.models.load_model(
str(model_path),
compile=False,
)
# model.load_weights(model_path / "val_binary_accuracy").expect_partial()
meta_file = model_path / "metadata.txt"
with open(meta_file, "r") as f:
meta = json.load(f)
Expand All @@ -73,11 +84,10 @@ def classify(file, model_file):
multi_label = meta.get("multi_label")
segment_length = meta.get("segment_length", 3)
segment_stride = meta.get("segment_stride", 1.5)
hop_length = meta.get("hop_length", 640)

segment_stride = meta.get("hop_length", 640)
samples, length = load_samples(file, segment_length, segment_stride, hop_length)
predictions = model.predict(samples, verbose=0)

tracks = []
start = 0
active_tracks = {}
Expand All @@ -86,14 +96,14 @@ def classify(file, model_file):
track_labels = []
if multi_label:
for i, p in enumerate(prediction):
if p > 0.7:
if p >= 0.7:
label = labels[i]
results.append((p, label))
track_labels.append(label)
else:
best_i = np.argmax(prediction)
best_p = prediction[best_i]
if best_p > 0.7:
if best_p >= 0.7:
label = labels[best_i]
results.append((best_p, label))
track_labels.append(label)
Expand Down Expand Up @@ -124,7 +134,6 @@ def classify(file, model_file):
# track = None

start += segment_stride

return [t.get_meta() for t in tracks]


Expand Down

0 comments on commit dde0061

Please sign in to comment.