Browse Source

fix accessing preloaded model

subDesTagesMitExtraKaese 2 years ago
parent
commit
cd3e70dbbc
3 changed files with 14 additions and 14 deletions
  1. 4 4
      .github/workflows/docker.yml
  2. 1 1
      main.py
  3. 9 9
      speech_recognition.py

+ 4 - 4
.github/workflows/docker.yml

@@ -34,7 +34,9 @@ jobs:
           context: .
           context: .
           platforms: linux/amd64,linux/arm64
           platforms: linux/amd64,linux/arm64
           push: true
           push: true
-          tags: ftcaplan/matrix-stt-bot:latest
+          tags: ftcaplan/matrix-stt-bot:tiny
+          build-args: |
+            "PRELOAD_MODEL=tiny"
       -
       -
         name: Build and push
         name: Build and push
         uses: docker/build-push-action@v2
         uses: docker/build-push-action@v2
@@ -42,9 +44,7 @@ jobs:
           context: .
           context: .
           platforms: linux/amd64,linux/arm64
           platforms: linux/amd64,linux/arm64
           push: true
           push: true
-          tags: ftcaplan/matrix-stt-bot:tiny
-          build-args: |
-            "PRELOAD_MODEL=tiny"
+          tags: ftcaplan/matrix-stt-bot:latest
       -
       -
         name: Build and push
         name: Build and push
         uses: docker/build-push-action@v2
         uses: docker/build-push-action@v2

+ 1 - 1
main.py

@@ -23,7 +23,7 @@ config.ignore_unverified_devices = True
 config.store_path = '/data/crypto_store/'
 config.store_path = '/data/crypto_store/'
 bot = botlib.Bot(creds, config)
 bot = botlib.Bot(creds, config)
 
 
-asr = ASR(os.getenv('ASR_MODEL', 'tiny'), os.getenv('ASR_LANGUAGE', 'en'))
+asr = ASR(os.getenv('ASR_MODEL', os.getenv('PRELOAD_MODEL', 'tiny')), os.getenv('ASR_LANGUAGE', 'en'))
 
 
 @bot.listener.on_custom_event(nio.RoomMessage)
 @bot.listener.on_custom_event(nio.RoomMessage)
 async def on_message(room, event):
 async def on_message(room, event):

+ 9 - 9
speech_recognition.py

@@ -3,7 +3,6 @@ import ffmpeg
 import asyncio
 import asyncio
 import subprocess
 import subprocess
 import os
 import os
-import shutil
 
 
 SAMPLE_RATE = 16000
 SAMPLE_RATE = 16000
 
 
@@ -37,19 +36,20 @@ class ASR():
     self.model = model
     self.model = model
     self.language = language
     self.language = language
 
 
-    if not os.path.exists("/data/models"):
-      os.mkdir("/data/models")
-    self.model_path = f"/data/models/ggml-{model}.bin"
+    if os.path.exists(f"/app/ggml-model-whisper-{model}.bin"):
+      self.model_path = f"/app/ggml-model-whisper-{model}.bin"
+    else:
+      self.model_path = f"/data/models/ggml-{model}.bin"
+      if not os.path.exists("/data/models"):
+        os.mkdir("/data/models")
+        
     self.model_url = f"https://ggml.ggerganov.com/ggml-model-whisper-{self.model}.bin"
     self.model_url = f"https://ggml.ggerganov.com/ggml-model-whisper-{self.model}.bin"
     self.lock = asyncio.Lock()
     self.lock = asyncio.Lock()
 
 
   def load_model(self):
   def load_model(self):
     if not os.path.exists(self.model_path) or os.path.getsize(self.model_path) == 0:
     if not os.path.exists(self.model_path) or os.path.getsize(self.model_path) == 0:
-      print("Fetching model...")
-      if os.path.exists(f"ggml-model-whisper-{self.model}.bin"):
-        shutil.copy(f"ggml-model-whisper-{self.model}.bin", self.model_path)
-      else:
-        subprocess.run(["wget", self.model_url, "-O", self.model_path], check=True)
+      print("Downloading model...")
+      subprocess.run(["wget", self.model_url, "-O", self.model_path], check=True)
       print("Done.")
       print("Done.")
 
 
   async def transcribe(self, audio: bytes) -> str:
   async def transcribe(self, audio: bytes) -> str: