speechlib/speechlib/core_analysis.py

import os
import subprocess
import time

import torch
import torchaudio
from pyannote.audio import Pipeline

from .convert_to_mono import convert_to_mono
from .convert_to_wav import convert_to_wav
from .re_encode import re_encode
from .speaker_recognition import speaker_recognition
from .wav_segmenter import wav_file_segmentation
from .write_log_file import write_log_file

# By default the Google speech-to-text API is used for transcription; otherwise a
# Whisper model fine-tuned for Sinhala is used (the backend is selected by the
# model_type and related arguments passed through to wav_file_segmentation).
def core_analysis(file_name, voices_folder, log_folder, language, modelSize,
                  ACCESS_TOKEN, model_type, quantization=False,
                  custom_model_path=None, hf_model_id=None, aai_api_key=None):
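    """Diarize, optionally identify, and transcribe the speakers in an audio file.

    Parameters (descriptions are inferred from the calls they feed below):
        file_name: path to the input audio file (converted to WAV if needed).
        voices_folder: folder with reference voice samples used for speaker
            recognition; pass None or "" to skip recognition.
        log_folder: folder where the transcript log file is written.
        language, modelSize, model_type, quantization, custom_model_path,
            hf_model_id, aai_api_key: forwarded to wav_file_segmentation to
            select and configure the transcription backend.
        ACCESS_TOKEN: Hugging Face token used to load the gated
            pyannote/speaker-diarization pipeline.

    Returns:
        A list of [start, end, text, speaker] segments covering the whole file.
    """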
    # <-------------------PreProcessing file-------------------------->

    # convert compressed wav
    # if file_name.lower().endswith(".wav"):
    #     subprocess.call(['ffmpeg', '-i', file_name, file_name.split(".")[0] + '.wav'])

    # check if file is in wav format, if not convert to wav
    file_name = convert_to_wav(file_name)

    # convert file to mono
    convert_to_mono(file_name)

    # re-encode file to 16-bit PCM encoding
    re_encode(file_name)
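    # At this point file_name refers to a mono, 16-bit PCM WAV file, the format
    # the diarization and transcription steps below operate on.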
    # <--------------------running analysis--------------------------->

    speaker_tags = []

    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization@2.1",
        use_auth_token=ACCESS_TOKEN,
    )

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
    else:
        device = torch.device("cpu")

    pipeline.to(device)

    waveform, sample_rate = torchaudio.load(file_name)

    start_time = int(time.time())
    print("running diarization...")

    diarization = pipeline(
        {"waveform": waveform, "sample_rate": sample_rate},
        min_speakers=0,
        max_speakers=10,
    )

    end_time = int(time.time())
    elapsed_time = int(end_time - start_time)
    print(f"diarization done. Time taken: {elapsed_time} seconds.")
    speakers = {}
    common = []

    # create a dictionary of SPEAKER_XX to real name mappings
    speaker_map = {}

    for turn, _, speaker in diarization.itertracks(yield_label=True):
        start = round(turn.start, 1)
        end = round(turn.end, 1)
        common.append([start, end, speaker])

        # find different speakers
        if speaker not in speaker_tags:
            speaker_tags.append(speaker)
            speaker_map[speaker] = speaker
            speakers[speaker] = []

        speakers[speaker].append([start, end, speaker])
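    # speakers: diarization tag -> list of that speaker's [start, end, tag] turns.
    # common: every turn in the order pyannote yielded it, used for the final log.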
    if voices_folder is not None and voices_folder != "":
        identified = []

        start_time = int(time.time())
        print("running speaker recognition...")

        for spk_tag, spk_segments in speakers.items():
            spk_name = speaker_recognition(file_name, voices_folder, spk_segments, identified)
            identified.append(spk_name)
            speaker_map[spk_tag] = spk_name

        end_time = int(time.time())
        elapsed_time = int(end_time - start_time)
        print(f"speaker recognition done. Time taken: {elapsed_time} seconds.")
    keys_to_remove = []
    merged = []

    # merging same speakers
    for spk_tag1, spk_segments1 in speakers.items():
        for spk_tag2, spk_segments2 in speakers.items():
            if (spk_tag1 not in merged and spk_tag2 not in merged
                    and spk_tag1 != spk_tag2
                    and speaker_map[spk_tag1] == speaker_map[spk_tag2]):
                for segment in spk_segments2:
                    speakers[spk_tag1].append(segment)
                merged.append(spk_tag1)
                merged.append(spk_tag2)
                keys_to_remove.append(spk_tag2)

    # fixing the speaker names in common
    for segment in common:
        speaker = segment[2]
        segment[2] = speaker_map[speaker]

    for key in keys_to_remove:
        del speakers[key]
        del speaker_map[key]
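    # After merging, diarization tags that resolved to the same recognized name
    # share a single entry in speakers, and common already carries resolved names.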
    # transcribing the texts differently according to speaker
    start_time = int(time.time())
    print("running transcription...")

    for spk_tag, spk_segments in speakers.items():
        segment_out = wav_file_segmentation(file_name, spk_segments, language, modelSize, model_type,
                                            quantization, custom_model_path, hf_model_id, aai_api_key)
        speakers[spk_tag] = segment_out

    end_time = int(time.time())
    elapsed_time = int(end_time - start_time)
    print(f"transcription done. Time taken: {elapsed_time} seconds.")
    common_segments = []

    for item in common:
        speaker = item[2]
        start = item[0]
        end = item[1]

        for spk_tag, spk_segments in speakers.items():
            if speaker == speaker_map[spk_tag]:
                for segment in spk_segments:
                    if start == segment[0] and end == segment[1]:
                        common_segments.append([start, end, segment[2], speaker])

    # writing log file
    write_log_file(common_segments, log_folder, file_name, language)

    # TODO: cleaning segments and temp folder
    return common_segments
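

# Minimal usage sketch (not part of the original module). The paths, token and
# backend name below are placeholders; model_type, modelSize and language must be
# values accepted by wav_file_segmentation in this package.
if __name__ == "__main__":
    segments = core_analysis(
        file_name="audio/meeting.wav",     # placeholder input file
        voices_folder="voices",            # reference samples; use None to skip recognition
        log_folder="logs",
        language="en",
        modelSize="medium",
        ACCESS_TOKEN="hf_xxx",             # Hugging Face token for the gated pyannote pipeline
        model_type="whisper",              # placeholder backend name
    )

    # each entry is [start, end, text, speaker]
    for start, end, text, speaker in segments:
        print(f"{speaker} [{start}-{end}]: {text}")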