diff --git a/bot.py b/bot.py index f8ab0ed..17f13a6 100644 --- a/bot.py +++ b/bot.py @@ -9,6 +9,7 @@ from typing import Optional, Union from textgen import textgen from textgen_markov import MarkovTextGenerator +from textgen_lstm import LSTMTextGenerator # TODO: Reenable and extend scraper # from models import Models @@ -57,17 +58,27 @@ class HeidiClient(discord.Client): # Textgen self.textgen_models: dict[str, textgen] = { # The name must correspond to the name of the training text file - "bibel": MarkovTextGenerator(3), # Prefix length of 3 - "kommunistisches_manifest": MarkovTextGenerator(3), - "musk": MarkovTextGenerator(3) + "kommunistisches_manifest": LSTMTextGenerator(10), + # "musk": LSTMTextGenerator(10), + # "bibel": LSTMTextGenerator(10) + + # "bibel": MarkovTextGenerator(3), # Prefix length of 3 + # "kommunistisches_manifest": MarkovTextGenerator(3), + # "musk": MarkovTextGenerator(3) } for name, model in self.textgen_models.items(): model.init(name) # Loads the textfile - if DOCKER: + + if os.path.exists(f"weights/{name}_lstm_model.pt"): model.load() - else: + elif not DOCKER: model.train() + else: + print("Error: Can't load model", name) + + print("Generating test sentence for", name) + self.textgen_models[name].generate_sentence() # Synchronize commands to guilds async def setup_hook(self):