Browse Source

Add image input on models that support it, fix some bugs, bump required OpenAI version

Kumi 1 year ago
parent
commit
4113a02232
5 changed files with 84 additions and 26 deletions
  1. 2 2
      pyproject.toml
  2. 1 1
      requirements.txt
  3. 68 8
      src/gptbot/classes/bot.py
  4. 12 14
      src/gptbot/classes/openai.py
  5. 1 1
      src/gptbot/commands/imagine.py

+ 2 - 2
pyproject.toml

@@ -7,7 +7,7 @@ allow-direct-references = true
 
 [project]
 name = "matrix-gptbot"
-version = "0.1.1"
+version = "0.2.0"
 
 authors = [
   { name="Kumi Mitterer", email="gptbot@kumi.email" },
@@ -38,7 +38,7 @@ dependencies = [
 
 [project.optional-dependencies]
 openai = [
-    "openai",
+    "openai>=1.2",
 ]
 
 wolframalpha = [

+ 1 - 1
requirements.txt

@@ -1,4 +1,4 @@
-openai
+openai>=1.2
 matrix-nio[e2e]
 markdown2[all]
 tiktoken

+ 68 - 8
src/gptbot/classes/bot.py

@@ -27,6 +27,12 @@ from nio import (
     RoomSendError,
     RoomVisibility,
     RoomCreateError,
+    RoomMessageMedia,
+    RoomMessageImage,
+    RoomMessageFile,
+    RoomMessageAudio,
+    DownloadError,
+    DownloadResponse,
 )
 from nio.crypto import Olm
 from nio.store import SqliteStore
@@ -38,6 +44,7 @@ from io import BytesIO
 from pathlib import Path
 from contextlib import closing
 
+import base64
 import uuid
 import traceback
 import json
@@ -139,7 +146,7 @@ class GPTBot:
 
         bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
             config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
-            config["OpenAI"].get("ImageModel"), bot.logger
+            config["OpenAI"].get("ImageModel"), config["OpenAI"].get("BaseURL"),  bot.logger
         )
         bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
         bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
@@ -220,6 +227,7 @@ class GPTBot:
         for event in response.chunk:
             if len(messages) >= n:
                 break
+
             if isinstance(event, MegolmEvent):
                 try:
                     event = await self.matrix_client.decrypt_event(event)
@@ -229,14 +237,22 @@ class GPTBot:
                         "error",
                     )
                     continue
-            if isinstance(event, (RoomMessageText, RoomMessageNotice)):
+
+            if isinstance(event, RoomMessageText):
                 if event.body.startswith("!gptbot ignoreolder"):
                     break
-                if (not event.body.startswith("!")) or (
-                    event.body.startswith("!gptbot") and not ignore_bot_commands
-                ):
+                if (not event.body.startswith("!")) or (not ignore_bot_commands):
                     messages.append(event)
 
+            if isinstance(event, RoomMessageNotice):
+                if not ignore_bot_commands:
+                    messages.append(event)
+
+            if isinstance(event, RoomMessageMedia):
+                if event.sender != self.matrix_client.user_id:
+                    if len(messages) < 2 or isinstance(messages[-1], RoomMessageMedia):
+                        messages.append(event)
+
         self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
 
         # Reverse the list so that messages are in chronological order
@@ -275,7 +291,7 @@ class GPTBot:
         truncated_messages = []
 
         for message in [messages[0]] + list(reversed(messages[1:])):
-            content = message["content"]
+            content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
             tokens = len(encoding.encode(content)) + 1
             if total_tokens + tokens > max_tokens:
                 break
@@ -906,14 +922,39 @@ class GPTBot:
 
         chat_messages = [{"role": "system", "content": system_message}]
 
-        for message in last_messages:
+        text_messages = list(filter(lambda x: not isinstance(x, RoomMessageMedia), last_messages))
+
+        for message in text_messages:
             role = (
                 "assistant" if message.sender == self.matrix_client.user_id else "user"
             )
             if not message.event_id == event.event_id:
                 chat_messages.append({"role": role, "content": message.body})
 
-        chat_messages.append({"role": "user", "content": event.body})
+        if not self.chat_api.supports_chat_images():
+            event_body = event.body
+        else:
+            event_body = [
+                {
+                    "type": "text",
+                    "text": event.body
+                }
+            ]
+
+            for m in list(filter(lambda x: isinstance(x, RoomMessageMedia), last_messages)):
+                image_url = m.url
+                download = await self.download_file(image_url)
+
+                if download:
+                    encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
+                    event_body.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": encoded_url
+                        }
+                    })
+
+        chat_messages.append({"role": "user", "content": event_body})
 
         # Truncate messages to fit within the token limit
         truncated_messages = self._truncate(
@@ -926,6 +967,7 @@ class GPTBot:
             )
         except Exception as e:
             self.logger.log(f"Error generating response: {e}", "error")
+
             await self.send_message(
                 room, "Something went wrong. Please try again.", True
             )
@@ -954,6 +996,24 @@ class GPTBot:
 
         await self.matrix_client.room_typing(room.room_id, False)
 
+    def download_file(self, mxc) -> Optional[bytes]:
+        """Download a file from the homeserver.
+
+        Args:
+            mxc (str): The MXC URI of the file to download.
+
+        Returns:
+            Optional[bytes]: The downloaded file, or None if there was an error.
+        """
+
+        download = self.matrix_client.download(mxc)
+
+        if isinstance(download, DownloadError):
+            self.logger.log(f"Error downloading file: {download.message}", "error")
+            return
+
+        return download
+
     def get_system_message(self, room: MatrixRoom | str) -> str:
         """Get the system message for a room.
 

+ 12 - 14
src/gptbot/classes/openai.py

@@ -25,12 +25,13 @@ class OpenAI:
 
     operator: str = "OpenAI ([https://openai.com](https://openai.com))"
 
-    def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
+    def __init__(self, api_key, chat_model=None, image_model=None, base_url=None, logger=None):
         self.api_key = api_key
         self.chat_model = chat_model or self.chat_model
         self.image_model = image_model or self.image_model
         self.logger = logger or Logger()
-        self.base_url = openai.api_base
+        self.base_url = base_url or openai.base_url
+        self.openai_api = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
     def supports_chat_images(self):
         return "vision" in self.chat_model
@@ -74,18 +75,20 @@ class OpenAI:
 
 
         chat_partial = partial(
-            openai.ChatCompletion.acreate,
+            self.openai_api.chat.completions.create,
                 model=self.chat_model,
                 messages=messages,
-                api_key=self.api_key,
                 user=user,
-                api_base=self.base_url,
+                max_tokens=4096
         )
         response = await self._request_with_retries(chat_partial)
 
+        self.logger.log(response, "error")
+        self.logger.log(response.choices, "error")
+        self.logger.log(response.choices[0].message, "error")
 
-        result_text = response.choices[0].message['content']
-        tokens_used = response.usage["total_tokens"]
+        result_text = response.choices[0].message.content
+        tokens_used = response.usage.total_tokens
         self.logger.log(f"Generated response with {tokens_used} tokens.")
         return result_text, tokens_used
 
@@ -117,13 +120,10 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
         self.logger.log(f"Classifying message '{query}'...")
 
         chat_partial = partial(
-            openai.ChatCompletion.acreate,
+            self.openai_api.chat.completions.create,
                 model=self.chat_model,
                 messages=messages,
-                api_key=self.api_key,
                 user=user,
-                api_base=self.base_url,
-                quality=("hd" if model == "dall-e-3" else "normal")
         )
         response = await self._request_with_retries(chat_partial)
 
@@ -150,14 +150,12 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
         self.logger.log(f"Generating image from prompt '{prompt}'...")
 
         image_partial = partial(
-            openai.Image.acreate,
+            self.openai_api.images.generate,
                 model=self.image_model,
                 prompt=prompt,
                 n=1,
-                api_key=self.api_key,
                 size="1024x1024",
                 user=user,
-                api_base=self.base_url,
         )
         response = await self._request_with_retries(image_partial)
 

+ 1 - 1
src/gptbot/commands/imagine.py

@@ -19,7 +19,7 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
             bot.logger.log(f"Sending image...")
             await bot.send_image(room, image)
 
-        bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_api}", tokens_used)
+        bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_model}", tokens_used)
 
         return