Browse Source

Adds retry logic for failed openai requests

Justin 1 year ago
parent
commit
b41a9ecd14
1 changed files with 53 additions and 17 deletions
  1. 53 17
      classes/openai.py

+ 53 - 17
classes/openai.py

@@ -1,11 +1,13 @@
 import openai
 import requests
 
+import asyncio
 import json
+from functools import partial
 
 from .logging import Logger
 
-from typing import Dict, List, Tuple, Generator, Optional
+from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
 
 class OpenAI:
     api_key: str
@@ -28,6 +30,32 @@ class OpenAI:
         self.chat_model = chat_model or self.chat_model
         self.logger = logger or Logger()
 
+    async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
+        """Retry a request a set number of times if it fails.
+
+        Args:
+            request (partial): The request to make with retries.
+            attempts (int, optional): The number of attempts to make. Defaults to 5.
+            retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
+
+        Returns:
+            AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
+        """
+        # call the request function and return the response if it succeeds, else retry
+        current_attempt = 1
+        while current_attempt <= attempts:
+            try:
+                response = await request()
+                return response
+            except Exception as e:
+                self.logger.log(f"Request failed: {e}", "error")
+                self.logger.log(f"Retrying in {retry_interval} seconds...")
+                await asyncio.sleep(retry_interval)
+                current_attempt += 1
+
+        # if all attempts failed, raise an exception
+        raise Exception("Request failed after all attempts.")
+
     async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
         """Generate a response to a chat message.
 
@@ -39,12 +67,16 @@ class OpenAI:
         """
         self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
 
-        response  = await openai.ChatCompletion.acreate(
-            model=self.chat_model,
-            messages=messages,
-            api_key=self.api_key,
-            user = user
+
+        chat_partial = partial(
+            openai.ChatCompletion.acreate,
+                model=self.chat_model,
+                messages=messages,
+                api_key=self.api_key,
+                user=user
         )
+        response = await self._request_with_retries(chat_partial)
+
 
         result_text = response.choices[0].message['content']
         tokens_used = response.usage["total_tokens"]
@@ -78,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
 
         self.logger.log(f"Classifying message '{query}'...")
 
-        response  = await openai.ChatCompletion.acreate(
-            model=self.chat_model,
-            messages=messages,
-            api_key=self.api_key,
-            user = user
+        chat_partial = partial(
+            openai.ChatCompletion.acreate,
+                model=self.chat_model,
+                messages=messages,
+                api_key=self.api_key,
+                user=user
         )
+        response = await self._request_with_retries(chat_partial)
 
         try:
             result = json.loads(response.choices[0].message['content'])
@@ -107,13 +141,15 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
         """
         self.logger.log(f"Generating image from prompt '{prompt}'...")
 
-        response = await openai.Image.acreate(
-            prompt=prompt,
-            n=1,
-            api_key=self.api_key,
-            size="1024x1024",
-            user = user
+        image_partial = partial(
+            openai.Image.acreate,
+                prompt=prompt,
+                n=1,
+                api_key=self.api_key,
+                size="1024x1024",
+                user=user
         )
+        response = await self._request_with_retries(image_partial)
 
         images = []