Browse Source

allow loading external models

subDesTagesMitExtraKaese 1 month ago
parent
commit
377763ced2
3 changed files with 16 additions and 13 deletions
  1. 7 5
      README.md
  2. 1 1
      main.py
  3. 8 7
      speech_recognition.py

+ 7 - 5
README.md

@@ -27,14 +27,16 @@ services:
 ```
 
 ## Configuration
-The bot will download the model file on first run to reduce image size. Available models are `tiny.en`, `tiny`, `base.en`, `base`, `small.en`, `small`, `medium.en`, `medium`, and `large`. The default is `ASR_MODEL=tiny`.
 
-You can authenticate using tokens instead of a password by setting `LOGIN_TOKEN=<login-token>` or `ACCESS_TOKEN=<access-token>` instead of `PASSWORD=<password>`.
+- **ASR_MODEL**: You can choose a model by setting it with `ASR_MODEL`. 
+  
+  - Available models are for example `tiny.en`, `tiny`, `base`, `small`, `medium`, and `large-v3`. The full list is available on [Hugging Face](https://huggingface.co/ggerganov/whisper.cpp). 
+  
+  - The default is `ASR_MODEL=tiny`. The bot will download the model file on first run to reduce image size.
 
-- **ASR_MODEL**: You can choose a docker tag with the corresponding model pre downloaded or set it with `ASR_MODEL`. Available models are `tiny.en`, `tiny`, `base.en`, `base`, `small.en`, `small`, `medium.en`, `medium`, and `large`. The default is `ASR_MODEL=tiny`.
+  - You can load your own ggml models by providing them at the following path: `/data/models/ggml-$ASR_MODEL.bin`
 
-- **Authentication**:
-  - You can authenticate using tokens instead of a password:
+- **Authentication**: You can authenticate using tokens instead of a password:
     - Set `LOGIN_TOKEN=<login-token>` or `ACCESS_TOKEN=<access-token>` instead of `PASSWORD=<password>`.
 
 - **Allowlist**:

+ 1 - 1
main.py

@@ -30,7 +30,7 @@ if 'ALLOWLIST' in os.environ:
 
 bot = botlib.Bot(creds, config)
 
-asr = ASR(os.getenv('ASR_MODEL', os.getenv('PRELOAD_MODEL', 'tiny')), os.getenv('ASR_LANGUAGE', 'en'))
+asr = ASR(os.getenv('ASR_MODEL', 'tiny'), os.getenv('ASR_LANGUAGE', 'en'))
 
 @bot.listener.on_custom_event(nio.RoomMessage)
 async def on_message(room, event):

+ 8 - 7
speech_recognition.py

@@ -62,11 +62,6 @@ MODELS = [
 
 class ASR():
   def __init__(self, model = "tiny", language = "en"):
-    if model not in MODELS:
-      raise ValueError(f"Invalid model: {model}. Must be one of {MODELS}")
-    self.model = model
-    self.language = language
-
     if os.path.exists(f"/app/ggml-{model}.bin"):
       self.model_path = f"/app"
     else:
@@ -74,11 +69,17 @@ class ASR():
       if not os.path.exists(self.model_path):
         os.mkdir(self.model_path)
 
+    file_path = f"{self.model_path}/ggml-{model}.bin"
+    if not os.path.exists(file_path) and model not in MODELS:
+      raise ValueError(f"Invalid model: {model}. Must be one of {MODELS}")
+
+    self.model = model
+    self.language = language
+    self.file_path = file_path
     self.lock = asyncio.Lock()
 
   def load_model(self):
-    file_path = f"{self.model_path}/ggml-{self.model}.bin"
-    if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
+    if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
       print("Downloading model...")
       subprocess.run(["./download-ggml-model.sh", self.model, self.model_path], check=True)
       print("Done.")