
Makes OpenAI calls awaitable

Justin, 1 year ago
commit f118a23714
4 changed files with 58 additions and 62 deletions:
  1. classes/bot.py (+10 -14)
  2. classes/openai.py (+46 -22)
  3. commands/classify.py (+1 -13)
  4. commands/imagine.py (+1 -13)
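The whole commit applies one pattern: each blocking openai.* call is bound into a functools.partial and handed to the running event loop's default thread-pool executor, so callers can simply await it. A minimal, standalone sketch of that pattern (blocking_call stands in for the OpenAI client calls and is not from this repository):

    import asyncio
    import functools

    def blocking_call(x: int) -> int:
        # stands in for e.g. openai.ChatCompletion.create
        return x * 2

    async def awaitable_call(x: int) -> int:
        loop = asyncio.get_running_loop()
        # None selects the loop's default ThreadPoolExecutor; the event loop
        # stays responsive while the blocking call occupies a worker thread
        return await loop.run_in_executor(None, functools.partial(blocking_call, x))

    print(asyncio.run(awaitable_call(21)))  # prints 42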

classes/bot.py (+10 -14)

@@ -746,8 +746,14 @@ class GPTBot:
         await self.matrix_client.room_read_markers(room.room_id, event.event_id)
 
         if (not from_chat_command) and self.room_uses_classification(room):
-            classification, tokens = self.classification_api.classify_message(
-                event.body, room.room_id)
+            try:
+                classification, tokens = await self.classification_api.classify_message(
+                    event.body, room.room_id)
+            except Exception as e:
+                self.logger.log(f"Error classifying message: {e}", "error")
+                await self.send_message(
+                    room, "Something went wrong. Please try again.", True)
+                return
 
             self.log_api_usage(
                 event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
@@ -781,18 +787,8 @@ class GPTBot:
             chat_messages, self.max_tokens - 1, system_message=system_message)
 
         try:
-            loop = asyncio.get_event_loop()
-        except Exception as e:
-            self.logger.log(f"Error getting event loop: {e}", "error")
-            await self.send_message(
-                room, "Something went wrong. Please try again.", True)
-            return
-
-        try:
-            chat_partial = functools.partial(self.chat_api.generate_chat_response, truncated_messages, user=room.room_id)
-            response, tokens_used = await loop.run_in_executor(None, chat_partial)
-            # response, tokens_used = self.chat_api.generate_chat_response(
-            #     chat_messages, user=room.room_id)
+            response, tokens_used = await self.chat_api.generate_chat_response(
+                truncated_messages, user=room.room_id)
         except Exception as e:
             self.logger.log(f"Error generating response: {e}", "error")
             await self.send_message(
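A side benefit of the now-awaitable call site above: a plain coroutine composes with asyncio's timeout and cancellation tooling, such as asyncio.wait_for, which is useful if an OpenAI call hangs. A standalone illustration, not code from this commit:

    import asyncio

    async def slow_api_call() -> str:
        await asyncio.sleep(2)  # stands in for an awaited OpenAI call
        return "response"

    async def main() -> None:
        try:
            print(await asyncio.wait_for(slow_api_call(), timeout=1))
        except asyncio.TimeoutError:
            print("API call timed out")

    asyncio.run(main())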

classes/openai.py (+46 -22)

@@ -1,6 +1,8 @@
 import openai
 import requests
 
+import asyncio
+import functools
 import json
 
 from .logging import Logger
@@ -17,7 +19,7 @@ class OpenAI:
     @property
     def chat_api(self) -> str:
         return self.chat_model
-    
+
     classification_api = chat_api
     image_api: str = "dalle"
 
@@ -28,7 +30,7 @@ class OpenAI:
         self.chat_model = chat_model or self.chat_model
         self.logger = logger or Logger()
 
-    def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
+    async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
         """Generate a response to a chat message.
 
         Args:
@@ -37,22 +39,29 @@ class OpenAI:
         Returns:
             Tuple[str, int]: The response text and the number of tokens used.
         """
+        # inside a coroutine a running loop is guaranteed; a bare "return" here
+        # would hand callers None where they expect a (text, tokens) tuple
+        loop = asyncio.get_running_loop()
 
         self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
 
-        response = openai.ChatCompletion.create(
-            model=self.chat_model,
-            messages=messages,
-            api_key=self.api_key,
-            user = user
+        chat_partial = functools.partial(
+            openai.ChatCompletion.create,
+            model=self.chat_model,
+            messages=messages,
+            api_key=self.api_key,
+            user=user,
         )
+        response = await loop.run_in_executor(None, chat_partial)
 
         result_text = response.choices[0].message['content']
         tokens_used = response.usage["total_tokens"]
         self.logger.log(f"Generated response with {tokens_used} tokens.")
         return result_text, tokens_used
 
-    def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
+    async def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
         system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this:
 
 { "type": event_type, "prompt": prompt }
@@ -66,10 +75,15 @@ class OpenAI:
 - If for any reason you are unable to classify the message (for example, if it infringes on your terms of service), the event_type is "error", and the prompt is a message explaining why you are unable to process the message.
 
 Only the event_types mentioned above are allowed, you must not respond in any other way."""
+        loop = asyncio.get_running_loop()
 
         messages = [
             {
-                "role": "system", 
+                "role": "system",
                 "content": system_message
             },
             {
@@ -80,12 +94,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
 
         self.logger.log(f"Classifying message '{query}'...")
 
-        response = openai.ChatCompletion.create(
-            model=self.chat_model,
-            messages=messages,
-            api_key=self.api_key,
-            user = user
+        chat_partial = functools.partial(
+            openai.ChatCompletion.create,
+            model=self.chat_model,
+            messages=messages,
+            api_key=self.api_key,
+            user=user,
         )
+        response = await loop.run_in_executor(None, chat_partial)
 
         try:
             result = json.loads(response.choices[0].message['content'])
@@ -98,7 +114,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
 
         return result, tokens_used
 
-    def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
+    async def generate_image(self, prompt: str, user: Optional[str] = None) -> Tuple[list, int]:
         """Generate an image from a prompt.
 
         Args:
@@ -107,16 +123,24 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
         Returns:
             Tuple[list, int]: The image data for each image, and the number of images.
         """
+        loop = asyncio.get_running_loop()
 
         self.logger.log(f"Generating image from prompt '{prompt}'...")
 
-        response = openai.Image.create(
-            prompt=prompt,
-            n=1,
-            api_key=self.api_key,
-            size="1024x1024",
-            user = user
+        image_partial = functools.partial(
+            openai.Image.create,
+            prompt=prompt,
+            n=1,
+            api_key=self.api_key,
+            size="1024x1024",
+            user=user,
         )
+        response = await loop.run_in_executor(None, image_partial)
 
         images = []
 
@@ -124,4 +148,4 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
             image = requests.get(image.url).content
             images.append(image)
 
-        return images, len(images)
+        return images, len(images)
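With all three methods now coroutines, one event loop can drive chat, classification, and image generation. A hedged usage sketch: the constructor arguments are inferred from the __init__ fragment above, and the api_key value is a placeholder:

    import asyncio
    from classes.openai import OpenAI

    async def main() -> None:
        api = OpenAI(api_key="sk-placeholder", chat_model="gpt-3.5-turbo")
        text, tokens = await api.generate_chat_response(
            [{"role": "user", "content": "Hello!"}], user="@alice:example.org")
        print(text, tokens)

    asyncio.run(main())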

commands/classify.py (+1 -13)

@@ -1,6 +1,3 @@
-import asyncio
-import functools
-
 from nio.events.room_events import RoomMessageText
 from nio.rooms import MatrixRoom
 
@@ -12,16 +9,7 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot):
         bot.logger.log("Classifying message...")
 
         try:
-            loop = asyncio.get_event_loop()
-        except Exception as e:
-            bot.logger.log(f"Error getting event loop: {e}", "error")
-            await bot.send_message(
-                room, "Something went wrong. Please try again.", True)
-            return
-
-        try:
-            classify_partial = functools.partial(bot.classification_api.classify_message, prompt, user=room.room_id)
-            response, tokens_used = await loop.run_in_executor(None, classify_partial)
+            response, tokens_used = await bot.classification_api.classify_message(prompt, user=room.room_id)
         except Exception as e:
             bot.logger.log(f"Error classifying message: {e}", "error")
             await bot.send_message(room, "Sorry, I couldn't classify the message. Please try again later.", True)

commands/imagine.py (+1 -13)

@@ -1,6 +1,3 @@
-import asyncio
-import functools
-
 from nio.events.room_events import RoomMessageText
 from nio.rooms import MatrixRoom
 
@@ -12,16 +9,7 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
         bot.logger.log("Generating image...")
 
         try:
-            loop = asyncio.get_event_loop()
-        except Exception as e:
-            bot.logger.log(f"Error getting event loop: {e}", "error")
-            await bot.send_message(
-                room, "Something went wrong. Please try again.", True)
-            return
-
-        try:
-            image_partial = functools.partial(bot.image_api.generate_image, prompt, user=room.room_id)
-            images, tokens_used = await loop.run_in_executor(None, image_partial)
+            images, tokens_used = await bot.image_api.generate_image(prompt, user=room.room_id)
         except Exception as e:
             bot.logger.log(f"Error generating image: {e}", "error")
             await bot.send_message(room, "Sorry, I couldn't generate an image. Please try again later.", True)