enable lstm in bot

This commit is contained in:
2022-11-13 14:26:39 +01:00
parent d5f1fbfbc4
commit 9e3b217331

21
bot.py
View File

@ -9,6 +9,7 @@ from typing import Optional, Union
from textgen import textgen from textgen import textgen
from textgen_markov import MarkovTextGenerator from textgen_markov import MarkovTextGenerator
from textgen_lstm import LSTMTextGenerator
# TODO: Reenable and extend scraper # TODO: Reenable and extend scraper
# from models import Models # from models import Models
@ -57,17 +58,27 @@ class HeidiClient(discord.Client):
# Textgen # Textgen
self.textgen_models: dict[str, textgen] = { self.textgen_models: dict[str, textgen] = {
# The name must correspond to the name of the training text file # The name must correspond to the name of the training text file
"bibel": MarkovTextGenerator(3), # Prefix length of 3 "kommunistisches_manifest": LSTMTextGenerator(10),
"kommunistisches_manifest": MarkovTextGenerator(3), # "musk": LSTMTextGenerator(10),
"musk": MarkovTextGenerator(3) # "bibel": LSTMTextGenerator(10)
# "bibel": MarkovTextGenerator(3), # Prefix length of 3
# "kommunistisches_manifest": MarkovTextGenerator(3),
# "musk": MarkovTextGenerator(3)
} }
for name, model in self.textgen_models.items(): for name, model in self.textgen_models.items():
model.init(name) # Loads the textfile model.init(name) # Loads the textfile
if DOCKER:
if os.path.exists(f"weights/{name}_lstm_model.pt"):
model.load() model.load()
else: elif not DOCKER:
model.train() model.train()
else:
print("Error: Can't load model", name)
print("Generating test sentence for", name)
self.textgen_models[name].generate_sentence()
# Synchronize commands to guilds # Synchronize commands to guilds
async def setup_hook(self): async def setup_hook(self):