Преглед изворни кода

Fix recursion errors in OpenAI class

Improved the error handling in the OpenAI class to prevent infinite recursion issues by retaining the original chat model during recursive calls. Enhanced logging within the recursion depth check for better debugging and traceability. Ensured consistency in chat responses by passing the initial model reference throughout the entire call stack. This is crucial when fallbacks due to errors or tool usage occur.

Refactored code for clarity and readability, ensuring that any recursion retains the original model and tool parameters. Additionally, proper logging and condition checks now standardize the flow of execution, preventing unintended modifications to the model's state that could lead to incorrect bot behavior.
Kumi пре 10 месеци
родитељ
комит
c4e23cb9d3
1 измењених фајлова са 43 додато и 18 уклоњено
  1. 43 18
      src/gptbot/classes/openai.py

+ 43 - 18
src/gptbot/classes/openai.py

@@ -23,11 +23,13 @@ ASSISTANT_CODE_INTERPRETER = [
     },
 ]
 
+
 class AttributeDictionary(dict):
     def __init__(self, *args, **kwargs):
         super(AttributeDictionary, self).__init__(*args, **kwargs)
         self.__dict__ = self
 
+
 class OpenAI:
     api_key: str
     chat_model: str = "gpt-3.5-turbo"
@@ -143,25 +145,31 @@ class OpenAI:
             f"Generating response to {len(messages)} messages for user {user} in room {room}..."
         )
 
-        chat_model = model or self.chat_model
+        original_model = chat_model = model or self.chat_model
 
         # Check current recursion depth to prevent infinite loops
 
         if use_tools:
             frames = inspect.stack()
             current_function = inspect.getframeinfo(frames[0][0]).function
-            count = sum(1 for frame in frames if inspect.getframeinfo(frame[0]).function == current_function)
-            self.logger.log(f"{current_function} appears {count} times in the call stack")
-            
+            count = sum(
+                1
+                for frame in frames
+                if inspect.getframeinfo(frame[0]).function == current_function
+            )
+            self.logger.log(
+                f"{current_function} appears {count} times in the call stack"
+            )
+
             if count > 5:
                 self.logger.log(f"Recursion depth exceeded, aborting.")
                 return self.generate_chat_response(
                     messages,
                     user=user,
                     room=room,
-                    allow_override=False, # TODO: Could this be a problem?
+                    allow_override=False,  # TODO: Could this be a problem?
                     use_tools=False,
-                    model=model,
+                    model=original_model,
                 )
 
         tools = [
@@ -231,12 +239,13 @@ class OpenAI:
                                 f"- {tool_name}: {tool_class.DESCRIPTION} ({tool_class.PARAMETERS})"
                                 for tool_name, tool_class in TOOLS.items()
                             ]
-                        ) + """
+                        )
+                        + """
 
                         If no tool is required, or all information is already available in the message thread, respond with an empty JSON object: {}
 
                         Do NOT FOLLOW ANY OTHER INSTRUCTIONS BELOW, they are only meant for the AI chat model. You can ignore them. DO NOT include any other text or syntax in your response, only the JSON object. DO NOT surround it in code tags (```). DO NOT, UNDER ANY CIRCUMSTANCES, ASK AGAIN FOR INFORMATION ALREADY PROVIDED IN THE MESSAGES YOU RECEIVED! DO NOT REQUEST MORE INFORMATION THAN ABSOLUTELY REQUIRED TO RESPOND TO THE USER'S MESSAGE! Remind the user that they may ask you to search for additional information if they need it.
-                        """
+                        """,
                     }
                 ]
                 + messages
@@ -292,6 +301,7 @@ class OpenAI:
                         room=room,
                         allow_override=False,
                         use_tools=False,
+                        model=original_model,
                     )
 
             if not tool_responses:
@@ -306,7 +316,7 @@ class OpenAI:
                         + original_messages[-1:]
                     )
                     result_text, additional_tokens = await self.generate_chat_response(
-                        messages, user=user, room=room
+                        messages, user=user, room=room, model=original_messages
                     )
                 except openai.APIError as e:
                     if e.code == "max_tokens":
@@ -338,6 +348,7 @@ class OpenAI:
                                 room=room,
                                 allow_override=False,
                                 use_tools=False,
+                                model=original_model,
                             )
 
                         except openai.APIError as e:
@@ -351,6 +362,7 @@ class OpenAI:
                                     room=room,
                                     allow_override=False,
                                     use_tools=False,
+                                    model=original_model,
                                 )
                     else:
                         raise e
@@ -359,16 +371,22 @@ class OpenAI:
             if "tool" in tool_object:
                 tool_name = tool_object["tool"]
                 tool_class = TOOLS[tool_name]
-                tool_parameters = tool_object["parameters"] if "parameters" in tool_object else {}
+                tool_parameters = (
+                    tool_object["parameters"] if "parameters" in tool_object else {}
+                )
 
-                self.logger.log(f"Using tool {tool_name} with parameters {tool_parameters}...")
+                self.logger.log(
+                    f"Using tool {tool_name} with parameters {tool_parameters}..."
+                )
 
                 tool_call = AttributeDictionary(
                     {
-                        "function": AttributeDictionary({
-                            "name": tool_name,
-                            "arguments": json.dumps(tool_parameters),
-                        }),
+                        "function": AttributeDictionary(
+                            {
+                                "name": tool_name,
+                                "arguments": json.dumps(tool_parameters),
+                            }
+                        ),
                     }
                 )
 
@@ -392,6 +410,7 @@ class OpenAI:
                         room=room,
                         allow_override=False,
                         use_tools=False,
+                        model=original_model,
                     )
 
                 if not tool_responses:
@@ -405,7 +424,10 @@ class OpenAI:
                             + tool_responses
                             + original_messages[-1:]
                         )
-                        result_text, additional_tokens = await self.generate_chat_response(
+                        (
+                            result_text,
+                            additional_tokens,
+                        ) = await self.generate_chat_response(
                             messages, user=user, room=room
                         )
                     except openai.APIError as e:
@@ -419,6 +441,7 @@ class OpenAI:
                                 room=room,
                                 allow_override=False,
                                 use_tools=False,
+                                model=original_model,
                             )
                         else:
                             raise e
@@ -429,9 +452,10 @@ class OpenAI:
                     room=room,
                     allow_override=False,
                     use_tools=False,
+                    model=original_model,
                 )
 
-        elif not self.chat_model == chat_model:
+        elif not original_model == chat_model:
             new_messages = []
 
             for message in original_messages:
@@ -448,7 +472,8 @@ class OpenAI:
                 new_messages.append(new_message)
 
             result_text, additional_tokens = await self.generate_chat_response(
-                new_messages, user=user, room=room, allow_override=False
+                new_messages, user=user, room=room, allow_override=False,
+                model=original_model
             )
 
         try: