瀏覽代碼

load models on demand

subDesTagesMitExtraKaese 2 年之前
父節點
當前提交
aedddaf812
共有 4 個文件被更改,包括 18 次插入和 5 次刪除
  1. 2 4
      Dockerfile
  2. 1 0
      docker-compose.yml.sample
  3. 1 0
      main.py
  4. 14 1
      speech_recognition.py

+ 2 - 4
Dockerfile

@@ -10,22 +10,20 @@ RUN apt-get update && apt-get install --no-install-recommends -y \
 ADD whisper.cpp/ /build/
 RUN gcc -pthread -O3 -march=native -c ggml.c && \
     g++ -pthread -O3 -std=c++11 -c main.cpp && \
-    g++ -pthread -o main ggml.o main.o && \
-    ./download-ggml-model.sh tiny
+    g++ -pthread -o main ggml.o main.o
 
 # main image
 FROM alpine
 WORKDIR /app/
 
 # Install dependencies
-RUN apk add ffmpeg py3-olm py3-matrix-nio py3-pip py3-pillow gcompat
+RUN apk add ffmpeg py3-olm py3-matrix-nio py3-pip py3-pillow gcompat wget
 
 ADD requirements.txt .
 
 RUN pip install -r requirements.txt
 
 COPY --from=builder /build/main /app/
-COPY --from=builder /build/models/ /app/models/
 
 VOLUME /data/
 

+ 1 - 0
docker-compose.yml.sample

@@ -9,4 +9,5 @@ services:
       - "HOMESERVER=https://matrix.example.com"
       - "USERNAME=@stt-bot:example.com"
       - "PASSWORD=<password>"
+      - "ASR_MODEL=tiny"
       

+ 1 - 0
main.py

@@ -49,4 +49,5 @@ async def on_audio_message(room, event):
         msgtype="m.notice")
 
 if __name__ == "__main__":
+  asr.load_model()
   bot.run()

+ 14 - 1
speech_recognition.py

@@ -19,14 +19,27 @@ def convert_audio(data: bytes) -> bytes:
 
   return out
 
+MODELS = ["tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large"]
+
 class ASR():
   def __init__(self, model = "tiny"):
+    if model not in MODELS:
+      raise ValueError(f"Invalid model: {model}. Must be one of {MODELS}")
     self.model = model
+    os.mkdir("/data/models")
+    self.model_path = f"/data/models/ggml-{model}.bin"
+    self.model_url = f"https://ggml.ggerganov.com/ggml-model-whisper-{self.model}.bin"
+
+  def load_model(self):
+    if not os.path.exists(self.model_path):
+      print("Downloading model...")
+      subprocess.run(["wget", self.model_url, "-O", self.model_path], check=True)
+      print("Done.")
 
   def transcribe(self, audio: bytes) -> str:
     convert_audio(audio)
     stdout, stderr = subprocess.Popen(
-        ["./main", "-m", f"models/ggml-{self.model}.bin", "-f", "audio.wav", "--no_timestamps"], 
+        ["./main", "-m", self.model_path, "-f", "audio.wav", "--no_timestamps"], 
         stdout=subprocess.PIPE
       ).communicate()