Browse Source

Merge branch 'justin-russell-bugfixes'

Kumi 1 year ago
parent
commit
3a1d1ea86a
5 changed files with 89 additions and 41 deletions
  1. 14 9
      classes/bot.py
  2. 59 26
      classes/openai.py
  3. 7 2
      commands/classify.py
  4. 7 2
      commands/imagine.py
  5. 2 2
      migrations/__init__.py

+ 14 - 9
classes/bot.py

@@ -1,8 +1,8 @@
 import markdown2
 import duckdb
 import tiktoken
-import magic
 import asyncio
+import functools
 
 from PIL import Image
 
@@ -27,12 +27,11 @@ from nio import (
     RoomLeaveError,
     RoomSendError,
     RoomVisibility,
-    RoomCreateResponse,
     RoomCreateError,
 )
 from nio.crypto import Olm
 
-from typing import Optional, List, Dict, Tuple
+from typing import Optional, List
 from configparser import ConfigParser
 from datetime import datetime
 from io import BytesIO
@@ -174,7 +173,7 @@ class GPTBot:
 
     async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
         messages = []
-        n = n or bot.max_messages
+        n = n or self.max_messages
         room_id = room.room_id if isinstance(room, MatrixRoom) else room
 
         self.logger.log(
@@ -585,7 +584,7 @@ class GPTBot:
             self.logger.log(
                 "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
             IN_MEMORY = True
-            self.database = DuckDBPyConnection(":memory:")
+            self.database = duckdb.DuckDBPyConnection(":memory:")
 
         self.logger.log("Running migrations...")
         before, after = migrate(self.database)
@@ -747,8 +746,14 @@ class GPTBot:
         await self.matrix_client.room_read_markers(room.room_id, event.event_id)
 
         if (not from_chat_command) and self.room_uses_classification(room):
-            classification, tokens = self.classification_api.classify_message(
-                event.body, room.room_id)
+            try:
+                classification, tokens = await self.classification_api.classify_message(
+                    event.body, room.room_id)
+            except Exception as e:
+                self.logger.log(f"Error classifying message: {e}", "error")
+                await self.send_message(
+                    room, "Something went wrong. Please try again.", True)
+                return
 
             self.log_api_usage(
                 event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
@@ -782,7 +787,7 @@ class GPTBot:
             chat_messages, self.max_tokens - 1, system_message=system_message)
 
         try:
-            response, tokens_used = self.chat_api.generate_chat_response(
+            response, tokens_used = await self.chat_api.generate_chat_response(
                 chat_messages, user=room.room_id)
         except Exception as e:
             self.logger.log(f"Error generating response: {e}", "error")
@@ -803,7 +808,7 @@ class GPTBot:
         else:
             # Send a notice to the room if there was an error
             self.logger.log("Didn't get a response from GPT API", "error")
-            await send_message(
+            await self.send_message(
                 room, "Something went wrong. Please try again.", True)
 
         await self.matrix_client.room_typing(room.room_id, False)

+ 59 - 26
classes/openai.py

@@ -1,11 +1,13 @@
 import openai
 import requests
 
+import asyncio
 import json
+from functools import partial
 
 from .logging import Logger
 
-from typing import Dict, List, Tuple, Generator, Optional
+from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
 
 class OpenAI:
     api_key: str
@@ -17,7 +19,7 @@ class OpenAI:
     @property
     def chat_api(self) -> str:
         return self.chat_model
-    
+
     classification_api = chat_api
     image_api: str = "dalle"
 
@@ -28,7 +30,33 @@ class OpenAI:
         self.chat_model = chat_model or self.chat_model
         self.logger = logger or Logger()
 
-    def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
+    async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
+        """Retry a request a set number of times if it fails.
+
+        Args:
+            request (partial): The request to make with retries.
+            attempts (int, optional): The number of attempts to make. Defaults to 5.
+            retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
+
+        Returns:
+            AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
+        """
+        # call the request function and return the response if it succeeds, else retry
+        current_attempt = 1
+        while current_attempt <= attempts:
+            try:
+                response = await request()
+                return response
+            except Exception as e:
+                self.logger.log(f"Request failed: {e}", "error")
+                self.logger.log(f"Retrying in {retry_interval} seconds...")
+                await asyncio.sleep(retry_interval)
+                current_attempt += 1
+
+        # if all attempts failed, raise an exception
+        raise Exception("Request failed after all attempts.")
+
+    async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
         """Generate a response to a chat message.
 
         Args:
@@ -37,22 +65,25 @@ class OpenAI:
         Returns:
             Tuple[str, int]: The response text and the number of tokens used.
         """
-
         self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
 
-        response = openai.ChatCompletion.create(
-            model=self.chat_model,
-            messages=messages,
-            api_key=self.api_key,
-            user = user
+
+        chat_partial = partial(
+            openai.ChatCompletion.acreate,
+                model=self.chat_model,
+                messages=messages,
+                api_key=self.api_key,
+                user=user
         )
+        response = await self._request_with_retries(chat_partial)
+
 
         result_text = response.choices[0].message['content']
         tokens_used = response.usage["total_tokens"]
         self.logger.log(f"Generated response with {tokens_used} tokens.")
         return result_text, tokens_used
 
-    def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
+    async def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
         system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this:
 
 { "type": event_type, "prompt": prompt }
@@ -66,10 +97,9 @@ class OpenAI:
 - If for any reason you are unable to classify the message (for example, if it infringes on your terms of service), the event_type is "error", and the prompt is a message explaining why you are unable to process the message.
 
 Only the event_types mentioned above are allowed, you must not respond in any other way."""
-
         messages = [
             {
-                "role": "system", 
+                "role": "system",
                 "content": system_message
             },
             {
@@ -80,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
 
         self.logger.log(f"Classifying message '{query}'...")
 
-        response = openai.ChatCompletion.create(
-            model=self.chat_model,
-            messages=messages,
-            api_key=self.api_key,
-            user = user
+        chat_partial = partial(
+            openai.ChatCompletion.acreate,
+                model=self.chat_model,
+                messages=messages,
+                api_key=self.api_key,
+                user=user
         )
+        response = await self._request_with_retries(chat_partial)
 
         try:
             result = json.loads(response.choices[0].message['content'])
@@ -98,7 +130,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
 
         return result, tokens_used
 
-    def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
+    async def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
         """Generate an image from a prompt.
 
         Args:
@@ -107,16 +139,17 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
         Yields:
             bytes: The image data.
         """
-
         self.logger.log(f"Generating image from prompt '{prompt}'...")
 
-        response = openai.Image.create(
-            prompt=prompt,
-            n=1,
-            api_key=self.api_key,
-            size="1024x1024",
-            user = user
+        image_partial = partial(
+            openai.Image.acreate,
+                prompt=prompt,
+                n=1,
+                api_key=self.api_key,
+                size="1024x1024",
+                user=user
         )
+        response = await self._request_with_retries(image_partial)
 
         images = []
 
@@ -124,4 +157,4 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
             image = requests.get(image.url).content
             images.append(image)
 
-        return images, len(images)
+        return images, len(images)

+ 7 - 2
commands/classify.py

@@ -8,7 +8,12 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot):
     if prompt:
         bot.logger.log("Classifying message...")
 
-        response, tokens_used = bot.classification_api.classify_message(prompt, user=room.room_id)
+        try:
+            response, tokens_used = await bot.classification_api.classify_message(prompt, user=room.room_id)
+        except Exception as e:
+            bot.logger.log(f"Error classifying message: {e}", "error")
+            await bot.send_message(room, "Sorry, I couldn't classify the message. Please try again later.", True)
+            return
 
         message = f"The message you provided seems to be of type: {response['type']}."
 
@@ -21,4 +26,4 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot):
 
         return
 
-    await bot.send_message(room, "You need to provide a prompt.", True)
+    await bot.send_message(room, "You need to provide a prompt.", True)

+ 7 - 2
commands/imagine.py

@@ -8,7 +8,12 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
     if prompt:
         bot.logger.log("Generating image...")
 
-        images, tokens_used = bot.image_api.generate_image(prompt, user=room.room_id)
+        try:
+            images, tokens_used = await bot.image_api.generate_image(prompt, user=room.room_id)
+        except Exception as e:
+            bot.logger.log(f"Error generating image: {e}", "error")
+            await bot.send_message(room, "Sorry, I couldn't generate an image. Please try again later.", True)
+            return
 
         for image in images:
             bot.logger.log(f"Sending image...")
@@ -18,4 +23,4 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
 
         return
 
-    await bot.send_message(room, "You need to provide a prompt.", True)
+    await bot.send_message(room, "You need to provide a prompt.", True)

+ 2 - 2
migrations/__init__.py

@@ -45,7 +45,7 @@ def migrate(db: DuckDBPyConnection, from_version: Optional[int] = None, to_versi
         raise ValueError("Cannot migrate from a higher version to a lower version.")
 
     for version in range(from_version, to_version):
-        if version in MIGRATIONS:
+        if version + 1 in MIGRATIONS:
             MIGRATIONS[version + 1](db)
 
-    return from_version, to_version
+    return from_version, to_version