|
@@ -19,14 +19,27 @@ def convert_audio(data: bytes) -> bytes:
|
|
|
|
|
|
return out
|
|
|
|
|
|
+MODELS = ["tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large"]
|
|
|
+
|
|
|
class ASR():
|
|
|
def __init__(self, model = "tiny"):
|
|
|
+ if model not in MODELS:
|
|
|
+ raise ValueError(f"Invalid model: {model}. Must be one of {MODELS}")
|
|
|
self.model = model
|
|
|
+ os.mkdir("/data/models")
|
|
|
+ self.model_path = f"/data/models/ggml-{model}.bin"
|
|
|
+ self.model_url = f"https://ggml.ggerganov.com/ggml-model-whisper-{self.model}.bin"
|
|
|
+
|
|
|
+ def load_model(self):
|
|
|
+ if not os.path.exists(self.model_path):
|
|
|
+ print("Downloading model...")
|
|
|
+ subprocess.run(["wget", self.model_url, "-O", self.model_path], check=True)
|
|
|
+ print("Done.")
|
|
|
|
|
|
def transcribe(self, audio: bytes) -> str:
|
|
|
convert_audio(audio)
|
|
|
stdout, stderr = subprocess.Popen(
|
|
|
- ["./main", "-m", f"models/ggml-{self.model}.bin", "-f", "audio.wav", "--no_timestamps"],
|
|
|
+ ["./main", "-m", self.model_path, "-f", "audio.wav", "--no_timestamps"],
|
|
|
stdout=subprocess.PIPE
|
|
|
).communicate()
|
|
|
|