# Snapshot of a Hugging Face Space app.py
# Author: morkovka1337 - commit 12a5951 ("Update app.py") - 3.51 kB
import html
import logging
import sys

import gradio as gr
from huggingsound import SpeechRecognitionModel
from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
# Logging and model setup, adapted from:
# https://huggingface.co/spaces/jonatasgrosman/asr/blob/main/app.py
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Hub id of the Russian wav2vec2 checkpoint, loaded below with AutoModelForCTC.
model_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-russian"

# The acoustic model is loaded once, at import time; run() fetches it from
# this dict under the "rus" key instead of reloading it per request.
CACHED_MODEL = dict(rus=AutoModelForCTC.from_pretrained(model_ID))
# Building a pipeline is expensive (the "LM" processor loads a language-model
# decoder from the Hub), so each decoding variant is built once and reused
# across requests instead of being rebuilt on every call.
_ASR_PIPELINES = {}


def _get_asr_pipeline(decoding_type):
    """Return the ASR pipeline for *decoding_type*, building and caching it on first use.

    "LM" pairs the cached CTC model with ``Wav2Vec2ProcessorWithLM`` and its
    decoder; any other value uses the plain ``Wav2Vec2Processor`` with no decoder.
    """
    asr = _ASR_PIPELINES.get(decoding_type)
    if asr is None:
        model_instance = CACHED_MODEL.get("rus")
        if decoding_type == "LM":
            processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_ID)
            decoder = processor.decoder
        else:
            processor = Wav2Vec2Processor.from_pretrained(model_ID)
            decoder = None
        asr = pipeline(
            "automatic-speech-recognition",
            model=model_instance,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            decoder=decoder,
        )
        _ASR_PIPELINES[decoding_type] = asr
    return asr


def _render_history_html(history):
    """Render history entries as the HTML fragment styled by the interface CSS.

    Entries with a non-None ``error_message`` render as error items; all others
    render as a link to the model page followed by the transcription. All
    interpolated text is HTML-escaped.
    """
    parts = ["<div class='result'>"]
    for item in history:
        if item["error_message"] is not None:
            parts.append(
                "<div class='result_item result_item_error'>"
                f"{html.escape(item['error_message'])}</div>"
            )
            continue
        url_suffix = " + LM" if item["decoding_type"] == "LM" else ""
        model_id = html.escape(item["model_id"])
        parts.append("<div class='result_item result_item_success'>")
        parts.append(
            f'<strong><a target="_blank" href="https://huggingface.co/{model_id}">'
            f"{model_id}{url_suffix}</a></strong><br/><br/>"
        )
        parts.append(f"{html.escape(item['transcription'])}<br/>")
        parts.append("</div>")
    parts.append("</div>")
    return "".join(parts)


def run(input_file, history, model_size="300M"):
    """Transcribe a recorded clip with the Russian model and render the result as HTML.

    Args:
        input_file: Gradio file-type audio value; its ``.name`` is the path handed
            to the ASR pipeline. ``None`` when nothing was recorded.
        history: Gradio session state. Currently ignored; see the comment below.
        model_size: Label only; appears in log messages and in the history entry.

    Returns:
        ``(html_output, history)`` for the interface's HTML output and state output.
        On failure the HTML shows an error item instead of raising.
    """
    language = "Russian"
    decoding_type = "LM"
    logger.info("Running ASR %s-%s-%s for %s", language, model_size, decoding_type, input_file)

    # The Gradio state no longer appears to be per-session, so history is
    # deliberately reset on every call rather than accumulated.
    history = []

    entry = {
        "model_id": model_ID,
        "language": language,
        "model_size": model_size,
        "decoding_type": decoding_type,
        "transcription": None,
        "error_message": None,
    }
    if input_file is None:
        entry["error_message"] = "No audio was recorded."
    else:
        try:
            asr = _get_asr_pipeline(decoding_type)
            # chunk_length_s / stride_length_s are in seconds (long audio is
            # processed in overlapping chunks).
            transcription = asr(input_file.name, chunk_length_s=5, stride_length_s=1)["text"]
        except Exception as err:
            # UI boundary: log the traceback and show the error in the page
            # instead of failing the whole request.
            logger.exception("ASR failed for %s", input_file)
            entry["error_message"] = str(err) or type(err).__name__
        else:
            logger.info(
                "Transcription for %s-%s-%s for %s: %s",
                language, model_size, decoding_type, input_file, transcription,
            )
            entry["transcription"] = transcription

    history.append(entry)
    return _render_history_html(history), history
# Styling for the HTML fragment returned by run(): a vertical list of result cards.
_RESULT_CSS = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
"""

# Microphone recording + session state in; HTML result + session state out.
demo = gr.Interface(
    fn=run,
    inputs=[gr.inputs.Audio(source="microphone", type="file", label="Record something..."), "state"],
    outputs=[gr.outputs.HTML(label="Outputs"), "state"],
    title="Automatic Speech Recognition",
    description="",
    css=_RESULT_CSS,
    allow_screenshot=False,
    allow_flagging="never",
    theme="grass",
)
demo.launch(enable_queue=True)