Prechádzať zdrojové kódy

Allow setting BaseURL for OpenAI API

Kumi 1 rok pred
rodič
commit
b1b274be57

+ 7 - 0
config.dist.ini

@@ -93,6 +93,13 @@ APIKey = sk-yoursecretkey
 #
 # MaxMessages = 20
 
+# The base URL of the OpenAI API
+#
+# Setting this allows you to use a self-hosted AI model for chat completions
+# using something like https://github.com/abetlen/llama-cpp-python
+#
+# BaseURL = https://openai.local/v1
+
 ###############################################################################
 
 [WolframAlpha]

+ 4 - 0
src/gptbot/classes/bot.py

@@ -143,6 +143,10 @@ class GPTBot:
         bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
         bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
 
+        if "BaseURL" in config["OpenAI"]:
+            bot.chat_api.base_url = config["OpenAI"]["BaseURL"]
+            bot.image_api = None
+
         # Set up WolframAlpha
         if "WolframAlpha" in config:
             bot.calculation_api = WolframAlpha(

+ 7 - 3
src/gptbot/classes/openai.py

@@ -29,6 +29,7 @@ class OpenAI:
         self.api_key = api_key
         self.chat_model = chat_model or self.chat_model
         self.logger = logger or Logger()
+        self.base_url = openai.api_base
 
     async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
         """Retry a request a set number of times if it fails.
@@ -73,7 +74,8 @@ class OpenAI:
                 model=self.chat_model,
                 messages=messages,
                 api_key=self.api_key,
-                user=user
+                user=user,
+                api_base=self.base_url,
         )
         response = await self._request_with_retries(chat_partial)
 
@@ -115,7 +117,8 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
                 model=self.chat_model,
                 messages=messages,
                 api_key=self.api_key,
-                user=user
+                user=user,
+                api_base=self.base_url,
         )
         response = await self._request_with_retries(chat_partial)
 
@@ -147,7 +150,8 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
                 n=1,
                 api_key=self.api_key,
                 size="1024x1024",
-                user=user
+                user=user,
+                api_base=self.base_url,
         )
         response = await self._request_with_retries(image_partial)