enable lstm in bot
This commit is contained in:
21
bot.py
21
bot.py
@ -9,6 +9,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
from textgen import textgen
|
from textgen import textgen
|
||||||
from textgen_markov import MarkovTextGenerator
|
from textgen_markov import MarkovTextGenerator
|
||||||
|
from textgen_lstm import LSTMTextGenerator
|
||||||
|
|
||||||
# TODO: Reenable and extend scraper
|
# TODO: Reenable and extend scraper
|
||||||
# from models import Models
|
# from models import Models
|
||||||
@ -57,17 +58,27 @@ class HeidiClient(discord.Client):
|
|||||||
# Textgen
|
# Textgen
|
||||||
self.textgen_models: dict[str, textgen] = {
|
self.textgen_models: dict[str, textgen] = {
|
||||||
# The name must correspond to the name of the training text file
|
# The name must correspond to the name of the training text file
|
||||||
"bibel": MarkovTextGenerator(3), # Prefix length of 3
|
"kommunistisches_manifest": LSTMTextGenerator(10),
|
||||||
"kommunistisches_manifest": MarkovTextGenerator(3),
|
# "musk": LSTMTextGenerator(10),
|
||||||
"musk": MarkovTextGenerator(3)
|
# "bibel": LSTMTextGenerator(10)
|
||||||
|
|
||||||
|
# "bibel": MarkovTextGenerator(3), # Prefix length of 3
|
||||||
|
# "kommunistisches_manifest": MarkovTextGenerator(3),
|
||||||
|
# "musk": MarkovTextGenerator(3)
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, model in self.textgen_models.items():
|
for name, model in self.textgen_models.items():
|
||||||
model.init(name) # Loads the textfile
|
model.init(name) # Loads the textfile
|
||||||
if DOCKER:
|
|
||||||
|
if os.path.exists(f"weights/{name}_lstm_model.pt"):
|
||||||
model.load()
|
model.load()
|
||||||
else:
|
elif not DOCKER:
|
||||||
model.train()
|
model.train()
|
||||||
|
else:
|
||||||
|
print("Error: Can't load model", name)
|
||||||
|
|
||||||
|
print("Generating test sentence for", name)
|
||||||
|
self.textgen_models[name].generate_sentence()
|
||||||
|
|
||||||
# Synchronize commands to guilds
|
# Synchronize commands to guilds
|
||||||
async def setup_hook(self):
|
async def setup_hook(self):
|
||||||
|
Reference in New Issue
Block a user