Merge pull request #20 from h1ddenpr0cess20/dev

bug fixes
This commit is contained in:
Dustin 2024-08-07 00:16:51 -04:00 committed by GitHub
commit a65b6ab340
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,10 +19,11 @@ class ollamarama:
config = json.load(f)
f.close()
self.server, self.username, self.password, self.channels, self.admins = config['matrix'].values()
self.server, self.username, self.password, self.channels, self.admins = config["matrix"].values()
self.api_url = config['ollama']['api_base'] + "/api/chat"
self.default_personality = config['ollama']['personality']
self.api_url = config["ollama"]["api_base"] + "/api/chat"
self.default_personality = config["ollama"]["personality"]
self.personality = self.default_personality
self.client = AsyncClient(self.server, self.username)
@ -34,14 +35,13 @@ class ollamarama:
self.messages = {}
#prompt parts
self.prompt = config['ollama']['prompt']
self.prompt = config["ollama"]["prompt"]
self.models = config['ollama']['models']
#set model
self.default_model = self.models[config['ollama']['default_model']]
self.models = config["ollama"]["models"]
self.default_model = self.models[config["ollama"]["default_model"]]
self.model = self.default_model
self.temperature, self.top_p, self.repeat_penalty = config['ollama']['options'].values()
self.temperature, self.top_p, self.repeat_penalty = config["ollama"]["options"].values()
self.defaults = {
"temperature": self.temperature,
"top_p": self.top_p,
@ -107,7 +107,7 @@ class ollamarama:
"repeat_penalty": self.repeat_penalty
}
}
response = requests.post(self.api_url, json=data)
response = requests.post(self.api_url, json=data, timeout=60)
response.raise_for_status()
data = response.json()
@ -116,7 +116,7 @@ class ollamarama:
print(e)
else:
#Extract response text
response_text = data["message"]['content']
response_text = data["message"]["content"]
#add to history
await self.add_history("assistant", channel, sender, response_text)
@ -134,7 +134,7 @@ class ollamarama:
print(e)
#Shrink history list for token size management
if len(self.messages[channel][sender]) > 24:
if self.messages[channel][sender][0]['role'] == 'system':
if self.messages[channel][sender][0]["role"] == "system":
del self.messages[channel][sender][1:3]
else:
del self.messages[channel][sender][0:2]
@ -149,7 +149,8 @@ class ollamarama:
personality = self.prompt[0] + persona + self.prompt[1]
#set system prompt
await self.add_history("system", channel, sender, personality)
await self.add_history("user", channel, sender, "introduce yourself")
# use a custom prompt
async def custom(self, channel, sender, prompt):
try:
@ -184,7 +185,7 @@ class ollamarama:
with open(self.config_file, "r") as f:
config = json.load(f)
f.close()
self.models = config['ollama']['models']
self.models = config["ollama"]["models"]
if message == ".models":
current_model = f"Current model: {self.model}\nAvailable models: {', '.join(sorted(list(self.models)))}"
await self.send_message(room_id, current_model)
@ -239,9 +240,9 @@ class ollamarama:
if message.startswith((".temperature ", ".top_p ", ".repeat_penalty ")):
attr_name = message.split()[0][1:]
min_val, max_val, default_val = {
"temperature": (0, 1, self.defaults['temperature']),
"top_p": (0, 1, self.defaults['top_p']),
"repeat_penalty": (0, 2, self.defaults['repeat_penalty'])
"temperature": (0, 1, self.defaults["temperature"]),
"top_p": (0, 1, self.defaults["top_p"]),
"repeat_penalty": (0, 2, self.defaults["repeat_penalty"])
}[attr_name]
if message.endswith(" reset"):