
add video support

subDesTagesMitExtraKaese 2 years ago
parent
commit 04e68651fd
2 changed files with 37 additions and 25 deletions
  1. main.py (+14, -8)
  2. speech_recognition.py (+23, -17)

main.py (+14, -8)

@@ -25,15 +25,16 @@ bot = botlib.Bot(creds, config)
 
 asr = ASR(os.getenv('ASR_MODEL', 'tiny'), os.getenv('ASR_LANGUAGE', 'en'))
 
-@bot.listener.on_custom_event(nio.RoomMessageAudio)
-async def on_message_audio(room, event):
-  await on_audio(room, event, False)
+@bot.listener.on_custom_event(nio.RoomMessage)
+async def on_message(room, event):
+  if not isinstance(event, (nio.RoomMessageAudio,
+                            nio.RoomEncryptedAudio,
+                            nio.RoomMessageVideo,
+                            nio.RoomEncryptedVideo)):
+    return
 
-@bot.listener.on_custom_event(nio.RoomEncryptedAudio)
-async def on_encrypted_audio(room, event):
-  await on_audio(room, event, True)
+  encrypted = isinstance(event, (nio.RoomEncryptedAudio, nio.RoomEncryptedVideo))
 
-async def on_audio(room, event, encrypted):
   print(room.machine_name, event.sender, event.body, event.url)
   match = botlib.MessageMatch(room, event, bot)
   if match.is_not_from_this_bot():
@@ -56,6 +57,11 @@ async def on_audio(room, event, encrypted):
     result = await asr.transcribe(data)
 
     await bot.async_client.room_typing(room.machine_name, False)
+
+    if not result:
+      print("No result")
+      return
+
     filename = response.filename or event.body
     if filename:
       reply = f"Transcription of {filename}: {result}"
@@ -63,7 +69,7 @@ async def on_audio(room, event, encrypted):
       reply = f"Transcription: {result}"
 
     await bot.api._send_room(
-        room_id=room.room_id,
+      room_id=room.room_id,
       content={
         "msgtype": "m.notice",
         "body": reply,

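Note: the main.py change above folds the separate audio listeners into one nio.RoomMessage handler that filters on the concrete event class and derives the encrypted flag from it, rather than passing it as an argument. A minimal sketch of that dispatch logic in isolation (the nio event classes are real matrix-nio types; classify() is a hypothetical helper, not part of this repo):

import nio

# Event classes the bot now reacts to: plain and encrypted audio, plus
# plain and encrypted video added by this commit.
MEDIA_EVENTS = (
    nio.RoomMessageAudio,
    nio.RoomEncryptedAudio,
    nio.RoomMessageVideo,
    nio.RoomEncryptedVideo,
)
ENCRYPTED_EVENTS = (nio.RoomEncryptedAudio, nio.RoomEncryptedVideo)

def classify(event):
    """Return (is_media, encrypted) the same way the new handler decides it."""
    if not isinstance(event, MEDIA_EVENTS):
        return False, False  # text, images, etc. are ignored early
    return True, isinstance(event, ENCRYPTED_EVENTS)
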
speech_recognition.py (+23, -17)

@@ -1,3 +1,4 @@
+import tempfile
 import ffmpeg
 import asyncio
 import subprocess
@@ -5,15 +6,19 @@ import os
 
 SAMPLE_RATE = 16000
 
-def convert_audio(data: bytes) -> bytes:
+def convert_audio(data: bytes, out_filename: str):
   try:
-    # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
-    # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
-    out, _ = (
-      ffmpeg.input("pipe:", threads=0)
-      .output("audio.wav", format="wav", acodec="pcm_s16le", ac=1, ar=SAMPLE_RATE)
-      .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=data)
-    )
+    with tempfile.NamedTemporaryFile("w+b") as file:
+      file.write(data)
+      file.flush()
+      print(f"Converting media {file.name} to {out_filename}")
+
+      out, _ = (
+        ffmpeg.input(file.name, threads=0)
+        .output(out_filename, format="wav", acodec="pcm_s16le", ac=1, ar=SAMPLE_RATE)
+        .overwrite_output()
+        .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=data)
+      )
   except ffmpeg.Error as e:
     raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
 
@@ -41,25 +46,26 @@ class ASR():
       print("Done.")
 
   async def transcribe(self, audio: bytes) -> str:
+    filename = tempfile.mktemp(suffix=".wav")
+    convert_audio(audio, filename)
     async with self.lock:
-      convert_audio(audio)
       proc = await asyncio.create_subprocess_exec(
           "./main",
           "-m", self.model_path,
           "-l", self.language,
-          "-f", "audio.wav",
+          "-f", filename,
           "--no_timestamps", 
           stdout=asyncio.subprocess.PIPE,
           stderr=asyncio.subprocess.PIPE
         )
       stdout, stderr = await proc.communicate()
 
-      os.remove("audio.wav")
+      os.remove(filename)
 
-      if stderr:
-        print(stderr.decode())
-        
-      text = stdout.decode()
-      print(text)
+    if stderr:
+      print(stderr.decode())
+      
+    text = stdout.decode().strip()
+    print(text)
 
-      return text
+    return text
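
Taken together, the speech_recognition.py changes replace the hard-coded audio.wav with per-request temporary files, let the same ffmpeg call demux video containers before resampling, and move the whisper.cpp output handling outside the lock. A hedged end-to-end sketch of that pipeline under the same assumptions as the diff (ffmpeg CLI plus ffmpeg-python installed, whisper.cpp built as ./main); the model path and input file below are placeholder assumptions, not values from this repo:

import asyncio
import os
import tempfile

import ffmpeg

SAMPLE_RATE = 16000  # whisper.cpp expects 16 kHz mono PCM

def media_to_wav(data: bytes) -> str:
    """Write downloaded media to a temp file and convert it to 16 kHz mono WAV.

    Mirrors convert_audio() from this commit; returns the WAV path so the
    caller can hand it to whisper.cpp and delete it afterwards.
    """
    wav_path = tempfile.mktemp(suffix=".wav")  # mktemp, as in the commit
    with tempfile.NamedTemporaryFile("w+b") as src:
        src.write(data)
        src.flush()
        (
            ffmpeg.input(src.name, threads=0)
            .output(wav_path, format="wav", acodec="pcm_s16le", ac=1, ar=SAMPLE_RATE)
            .overwrite_output()
            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
        )
    return wav_path

async def transcribe_file(path: str, model_path: str = "models/ggml-tiny.bin") -> str:
    # Hypothetical driver: read a local audio/video file, convert it, then run
    # the whisper.cpp CLI the same way ASR.transcribe() does.
    with open(path, "rb") as f:
        wav = media_to_wav(f.read())
    proc = await asyncio.create_subprocess_exec(
        "./main", "-m", model_path, "-l", "en", "-f", wav, "--no_timestamps",
        stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
    )
    stdout, _ = await proc.communicate()
    os.remove(wav)  # clean up the temp WAV, as the commit does
    return stdout.decode().strip()

# Example usage (placeholder file name):
#   print(asyncio.run(transcribe_file("sample-video.mp4")))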