bug fixes

This commit is contained in:
Dustin 2024-08-07 00:15:49 -04:00
parent dcc41fef6b
commit e746154626

View File

@ -19,10 +19,11 @@ class ollamarama:
config = json.load(f) config = json.load(f)
f.close() f.close()
self.server, self.username, self.password, self.channels, self.admins = config['matrix'].values() self.server, self.username, self.password, self.channels, self.admins = config["matrix"].values()
self.api_url = config['ollama']['api_base'] + "/api/chat" self.api_url = config["ollama"]["api_base"] + "/api/chat"
self.default_personality = config['ollama']['personality']
self.default_personality = config["ollama"]["personality"]
self.personality = self.default_personality self.personality = self.default_personality
self.client = AsyncClient(self.server, self.username) self.client = AsyncClient(self.server, self.username)
@ -34,14 +35,13 @@ class ollamarama:
self.messages = {} self.messages = {}
#prompt parts #prompt parts
self.prompt = config['ollama']['prompt'] self.prompt = config["ollama"]["prompt"]
self.models = config['ollama']['models'] self.models = config["ollama"]["models"]
#set model self.default_model = self.models[config["ollama"]["default_model"]]
self.default_model = self.models[config['ollama']['default_model']]
self.model = self.default_model self.model = self.default_model
self.temperature, self.top_p, self.repeat_penalty = config['ollama']['options'].values() self.temperature, self.top_p, self.repeat_penalty = config["ollama"]["options"].values()
self.defaults = { self.defaults = {
"temperature": self.temperature, "temperature": self.temperature,
"top_p": self.top_p, "top_p": self.top_p,
@ -107,7 +107,7 @@ class ollamarama:
"repeat_penalty": self.repeat_penalty "repeat_penalty": self.repeat_penalty
} }
} }
response = requests.post(self.api_url, json=data) response = requests.post(self.api_url, json=data, timeout=60)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
@ -116,7 +116,7 @@ class ollamarama:
print(e) print(e)
else: else:
#Extract response text #Extract response text
response_text = data["message"]['content'] response_text = data["message"]["content"]
#add to history #add to history
await self.add_history("assistant", channel, sender, response_text) await self.add_history("assistant", channel, sender, response_text)
@ -134,7 +134,7 @@ class ollamarama:
print(e) print(e)
#Shrink history list for token size management #Shrink history list for token size management
if len(self.messages[channel][sender]) > 24: if len(self.messages[channel][sender]) > 24:
if self.messages[channel][sender][0]['role'] == 'system': if self.messages[channel][sender][0]["role"] == "system":
del self.messages[channel][sender][1:3] del self.messages[channel][sender][1:3]
else: else:
del self.messages[channel][sender][0:2] del self.messages[channel][sender][0:2]
@ -149,6 +149,7 @@ class ollamarama:
personality = self.prompt[0] + persona + self.prompt[1] personality = self.prompt[0] + persona + self.prompt[1]
#set system prompt #set system prompt
await self.add_history("system", channel, sender, personality) await self.add_history("system", channel, sender, personality)
await self.add_history("user", channel, sender, "introduce yourself")
# use a custom prompt # use a custom prompt
async def custom(self, channel, sender, prompt): async def custom(self, channel, sender, prompt):
@ -184,7 +185,7 @@ class ollamarama:
with open(self.config_file, "r") as f: with open(self.config_file, "r") as f:
config = json.load(f) config = json.load(f)
f.close() f.close()
self.models = config['ollama']['models'] self.models = config["ollama"]["models"]
if message == ".models": if message == ".models":
current_model = f"Current model: {self.model}\nAvailable models: {', '.join(sorted(list(self.models)))}" current_model = f"Current model: {self.model}\nAvailable models: {', '.join(sorted(list(self.models)))}"
await self.send_message(room_id, current_model) await self.send_message(room_id, current_model)
@ -239,9 +240,9 @@ class ollamarama:
if message.startswith((".temperature ", ".top_p ", ".repeat_penalty ")): if message.startswith((".temperature ", ".top_p ", ".repeat_penalty ")):
attr_name = message.split()[0][1:] attr_name = message.split()[0][1:]
min_val, max_val, default_val = { min_val, max_val, default_val = {
"temperature": (0, 1, self.defaults['temperature']), "temperature": (0, 1, self.defaults["temperature"]),
"top_p": (0, 1, self.defaults['top_p']), "top_p": (0, 1, self.defaults["top_p"]),
"repeat_penalty": (0, 2, self.defaults['repeat_penalty']) "repeat_penalty": (0, 2, self.defaults["repeat_penalty"])
}[attr_name] }[attr_name]
if message.endswith(" reset"): if message.endswith(" reset"):