21 Commits
v3.0 ... master

Author SHA1 Message Date
d04244221b Replace ENTRANCE.SOUND menu with dropdowns
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
2023-12-09 23:01:24 +01:00
c2847de7dd Add instantbuttons command + make responses ephemeral
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
/instantbuttons displays a soundboard via a button ui
2023-12-09 19:51:28 +01:00
08230eb3de Enforce heidi_spam channel for commands
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
2023-12-09 18:44:16 +01:00
f2ddb4ab66 Only play entrance sound when other is present + reformat
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
2023-12-09 18:04:24 +01:00
876232f674 Ignore user config file 2023-12-09 18:03:42 +01:00
d7c3a7c740 Allow sounds with different file extensions
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
Before only .mkv files could be played, as the extension was hardcoded
2023-12-09 17:55:21 +01:00
bdcd5208a7 Untrack Heidi_User.conf 2023-12-09 17:54:56 +01:00
79fcf0142a Some more options for randomly selected answers 2023-12-09 17:48:27 +01:00
0f6cc12182 Delete orphaned code
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 27s
2023-12-09 17:36:53 +01:00
9b66061ee7 Reformat TODO comments
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 15s
2023-12-09 17:35:04 +01:00
c6608e4695 Remove rocm stuff from flake 2023-12-09 17:34:54 +01:00
3335009692 Fix SOUNDDIR being in the wrong file
Some checks failed
Build Heidi Docker image / build-docker (push) Failing after 17s
2023-12-09 17:28:09 +01:00
d7604b6604 Update flake.lock 2023-12-09 17:20:46 +01:00
2e493e404b Split Heidi into multiple parts
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 28s
2023-12-09 17:14:05 +01:00
16822e0212 Remove commented out code + add docstrings 2023-12-09 16:31:00 +01:00
9d78352ea5 Update handling of "None" 2023-12-09 15:48:19 +01:00
13b3e9910a Add sounds 2023-12-09 15:28:26 +01:00
1b89d2ef3b Update flake.lock 2023-12-09 15:28:20 +01:00
6debffbd77 Add Suiii sound 2023-11-27 20:50:58 +01:00
e08c1c0204 Add joko sounds 2023-11-26 12:18:13 +01:00
82f0387675 Add yakari sound 2023-11-26 00:59:04 +01:00
22 changed files with 526 additions and 941 deletions

1
.gitignore vendored
View File

@ -12,3 +12,4 @@ Pipfile.lock
/disabled_voicelines/
*.svg
.vscode
Heidi_User.conf

View File

@ -1,38 +0,0 @@
workflow: # for entire pipeline
rules:
- if: '$CI_COMMIT_REF_NAME == "master"' # only run on master...
changes: # ...and when these files have changed
- "*.py"
- "Dockerfile"
docker-build:
stage: build
image: docker:20 # provides the docker toolset (but without an active daemon)
services: # configure images that run during jobs linked to the image (above)
- docker:dind # dind build on docker and starts up the dockerdaemon (docker itself doesn't do that), which is needed to call docker build etc.
before_script:
- docker login -u $CI_REGISTRY_USER -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY
script:
- docker pull $CI_REGISTRY_IMAGE:latest || true # latest image for cache (not failing if image is not found)
- >
docker build
--pull
--cache-from $CI_REGISTRY_IMAGE:latest
--label "org.opencontainers.image.title=$CI_PROJECT_TITLE"
--label "org.opencontainers.image.url=$CI_PROJECT_URL"
--label "org.opencontainers.image.created=$CI_JOB_STARTED_AT"
--label "org.opencontainers.image.revision=$CI_COMMIT_SHA"
--label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME"
--tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA
.
- docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA $CI_REGISTRY_IMAGE:latest
- docker push $CI_REGISTRY_IMAGE:latest
docker-deploy:
stage: deploy
image: alpine:3.15
needs: ["docker-build"]
script:
- chmod og= $ID_RSA
- apk update && apk add openssh-client
- ssh -i $ID_RSA -o StrictHostKeyChecking=no $SERVER_USER@$SERVER_IP "/home/christoph/$CI_PROJECT_TITLE/launch.sh"

606
bot.py
View File

@ -1,221 +1,21 @@
# Example: https://github.com/Rapptz/discord.py/blob/master/examples/app_commands/basic.py
import os, re, random, logging, asyncio, discord, configparser
from discord import app_commands
from ast import Call
import random, logging
from discord import DMChannel
from discord.app_commands import Choice
from functools import reduce
from dotenv import load_dotenv
from typing import Dict, List, Optional, Union
# TODO: Reenable + extend textgen
# from textgen import textgen
# from textgen_markov import MarkovTextGenerator
# from textgen_lstm import LSTMTextGenerator
# TODO: Reenable + extend scraper
# from models import Models
# We're fancy today
from typing import Awaitable, Dict, List, Optional, Union, Callable, Any
from rich.traceback import install
from heidi_client import *
# Install rich traceback
install(show_locals=True)
load_dotenv()
# ================================================================================================ #
# ================================================================================================ #
# NOTE: Always set this correctly:
DOCKER = os.getenv("DOCKER") == "True"
# ================================================================================================ #
# ================================================================================================ #
# @todo yt-dlp music support
# TODO: Only post in heidi-spam channel
# TODO: yt-dlp music support
# TODO: Somehow upload voicelines more easily (from discord voice message?)
# IDs of the servers Heidi is used on
LINUS_GUILD = discord.Object(id=431154792308408340)
TEST_GUILD = discord.Object(id=821511861178204161)
CONFIGPATH = "/config" if DOCKER else "."
USERCONFIGNAME = "Heidi_User.conf"
class HeidiClient(discord.Client):
    """
    The Heidi Discord client.

    Holds the application command tree, the persistent per-user configuration
    (a ConfigParser backed by the file CONFIGPATH/USERCONFIGNAME) and two
    trigger tables that map a predicate to the action performed when the
    predicate is true.
    """

    def __init__(self, *, intents: discord.Intents):
        super().__init__(status="Nur eine kann GNTM werden!", intents=intents)

        # Separate object that keeps all application command state
        self.tree = app_commands.CommandTree(self)

        # Handle persistent user configuration
        self.user_config = configparser.ConfigParser()
        config_file = f"{CONFIGPATH}/{USERCONFIGNAME}"
        if not os.path.exists(config_file):
            # Create an empty file with plain open(): os.mknod() is only
            # available on Unix. Append mode never truncates an existing file.
            with open(config_file, "a"):
                pass
        self.user_config.read(config_file)
        self.update_to_default_user_config()
        self.print_user_config()

        # TODO: Reenable the scraped model list and the textgen models

        # Automatic actions on all messages.
        # on_message_triggers is a dict {predicate: action}: the predicate
        # receives the message, and if it is true the action is performed.
        self.on_message_triggers = {
            lambda m: "jeremy" in self._author_nick(m): self._autoreact_to_jeremy,
            lambda m: "kardashian" in self._author_nick(m)
            or "jenner" in self._author_nick(m): self._autoreact_to_kardashian,
        }

        # Automatic actions on voice state changes.
        # on_voice_state_triggers is a dict {predicate: action}: the predicate
        # receives the member, the before- and the after-state.
        self.on_voice_state_triggers = {
            lambda m, b, a: b.channel != a.channel
            and a.channel is not None
            and isinstance(a.channel, discord.VoiceChannel): self._play_entrance_sound,
        }

    async def setup_hook(self):
        """
        Synchronize the global commands to the known guilds.
        """
        self.tree.copy_global_to(guild=LINUS_GUILD)
        await self.tree.sync(guild=LINUS_GUILD)
        self.tree.copy_global_to(guild=TEST_GUILD)
        await self.tree.sync(guild=TEST_GUILD)

    def update_to_default_user_config(self):
        """
        Adds config keys to the config, if they don't exist yet.
        The config file is written afterwards.
        """
        user_config_sections = ["ENTRANCE.SOUND"]

        for section in user_config_sections:
            if section not in self.user_config:
                print(f"Adding section {section} to {CONFIGPATH}/{USERCONFIGNAME}")
                self.user_config[section] = dict()

        self.write_user_config()

    def print_user_config(self):
        """
        Print every section and key=value pair of the loaded user config.
        """
        print("Read persistent configuration:\n")

        for section in self.user_config.sections():
            print(f"[{section}]")
            for key in self.user_config[section]:
                print(f"{key}={self.user_config[section][key]}")

        print("")

    def write_user_config(self):
        """
        Write the user config to its file. Does nothing but print an error
        if the file doesn't exist.
        """
        if not os.path.exists(f"{CONFIGPATH}/{USERCONFIGNAME}"):
            print(f"Error: {CONFIGPATH}/{USERCONFIGNAME} doesn't exist!")
            return

        print(f"Writing {CONFIGPATH}/{USERCONFIGNAME}")

        with open(f"{CONFIGPATH}/{USERCONFIGNAME}", "w") as file:
            self.user_config.write(file)

    # Automatic Actions --------------------------------------------------------------------------

    @staticmethod
    def _author_nick(message: discord.Message) -> str:
        """
        Return the lowercased nickname of the message author, or "" if the
        author has no nickname (nick is None, or the author has no nick attribute).
        """
        return (getattr(message.author, "nick", None) or "").lower()

    @staticmethod
    async def _autoreact_to_jeremy(message: discord.Message):
        """
        🧀 Jeremy
        """
        await message.add_reaction("🧀")

    @staticmethod
    async def _autoreact_to_kardashian(message: discord.Message):
        """
        💄 Kardashian
        """
        await message.add_reaction("💄")

    async def _play_entrance_sound(
        self,
        member: discord.Member,
        before: discord.VoiceState,
        after: discord.VoiceState,
    ):
        """
        Play the entrance sound configured for the member under ENTRANCE.SOUND.
        The stored value has the form "board/sound".
        """
        soundpath: Union[str, None] = self.user_config["ENTRANCE.SOUND"].get(
            member.name, None
        )

        if soundpath is None:
            print(f"User {member.name} has not set an entrance sound")
            return

        # partition() instead of split(): a stored value without (or with more
        # than one) "/" must not raise on unpacking.
        board, separator, sound = soundpath.partition("/")
        if not separator:
            print(f"Error: Invalid entrance sound {soundpath} for user {member.name}")
            return

        # Wait a bit to not have simultaneous joins
        await asyncio.sleep(1)

        await play_voice_line_for_member(None, member, board, sound)
# ------------------------------------------------------------------------------------------------
# Log to file
handler = logging.FileHandler(filename="discord.log", encoding="utf-8", mode="w")
@ -225,23 +25,30 @@ intents.members = True # Allow to react to member join/leave etc
intents.message_content = True # Allow to read message content from arbitrary messages
intents.voice_states = True # Allow to process on_voice_state_update
# Setup our client
# Set up our client
client = HeidiClient(intents=intents)
# Events -----------------------------------------------------------------------------------------
# NOTE: I defined the events outside of the Client class, don't know if I like it or not...
# NOTE: I defined the events outside the Client class, don't know if I like it or not...
@client.event
async def on_ready():
if client.user != None:
async def on_ready() -> None:
"""
This event triggers when the Heidi client has finished connecting.
"""
if client.user is not None:
print(f"{client.user} (id: {client.user.id}) has connected to Discord!")
else:
print("client.user is None!")
@client.event
async def on_message(message: discord.Message):
async def on_message(message: Message) -> None:
"""
This event triggers when a message is sent in any text channel.
"""
# Skip Heidi's own messages
if message.author == client.user:
return
@ -257,8 +64,11 @@ async def on_message(message: discord.Message):
@client.event
async def on_voice_state_update(
member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
):
member: Member, before: VoiceState, after: VoiceState
) -> None:
"""
This event triggers when a member joins/changes/leaves a voice channel or mutes/unmutes.
"""
# Skip Heidi's own voice state updates (e.g. on /say)
if member._user == client.user:
return
@ -267,7 +77,7 @@ async def on_voice_state_update(
# python iterates over the keys of a map
for predicate in client.on_voice_state_triggers:
if predicate(member, before, after):
action = client.on_voice_state_triggers[predicate]
action: Callable = client.on_voice_state_triggers[predicate]
print(f"on_voice_state_update: calling {action.__name__}")
await action(member, before, after)
@ -275,9 +85,65 @@ async def on_voice_state_update(
# Config Commands --------------------------------------------------------------------------------
class EntranceSoundSoundSelect(discord.ui.Select):
    """
    Dropdown that lists the sound files of a single soundboard.

    When an option is picked, on_sound_select_callback is awaited with the
    interaction, the board and the selected file name (including extension).
    """

    def __init__(self, board: str, on_sound_select_callback):
        self.board = board
        self.on_sound_select_callback = on_sound_select_callback

        # The label is the file name without its extension. os.path.splitext
        # keeps dots inside the name ("a.b.mkv" -> "a.b") and never yields an
        # empty label for dotfiles, unlike split(".")[0].
        # Sorted, because os.listdir returns entries in arbitrary order.
        # NOTE(review): Discord limits a select menu to 25 options - confirm
        # how boards with more sounds should be handled.
        options: List[discord.SelectOption] = [
            discord.SelectOption(label=os.path.splitext(sound)[0], value=sound)
            for sound in sorted(os.listdir(f"{SOUNDDIR}/{board}"))
        ]

        super().__init__(
            placeholder="Select Sound", min_values=1, max_values=1, options=options
        )

    async def callback(self, interaction: Interaction):
        """
        Forward the selected sound to the stored callback.
        """
        await self.on_sound_select_callback(interaction, self.board, self.values[0])
class EntranceSoundSoundView(discord.ui.View):
    """
    View that contains a single EntranceSoundSoundSelect for the given board.
    """

    def __init__(self, board: str, on_sound_select_callback):
        super().__init__(timeout=600)

        sound_select = EntranceSoundSoundSelect(board, on_sound_select_callback)
        self.add_item(sound_select)
class EntranceSoundBoardSelect(discord.ui.Select):
    """
    Dropdown that lists all soundboards (the entries of SOUNDDIR).

    Picking a board answers with a second dropdown for the sounds of that board.
    """

    def __init__(self, on_sound_select_callback):
        self.on_sound_select_callback = on_sound_select_callback

        # One option per entry of the sound directory
        board_options: List[discord.SelectOption] = []
        for board in os.listdir(f"{SOUNDDIR}"):
            board_options.append(discord.SelectOption(label=board, value=board))

        super().__init__(
            placeholder="Select Board",
            min_values=1,
            max_values=1,
            options=board_options,
        )

    async def callback(self, interaction: Interaction):
        """
        Ask for the sound of the selected board via an ephemeral follow-up view.
        """
        selected_board = self.values[0]
        sound_view = EntranceSoundSoundView(
            selected_board, self.on_sound_select_callback
        )

        await interaction.response.send_message(
            f"Welchen sound willst du?",
            view=sound_view,
            ephemeral=True,
        )
class EntranceSoundBoardView(discord.ui.View):
    """
    View that contains a single EntranceSoundBoardSelect.
    """

    def __init__(self, on_sound_select_callback):
        super().__init__(timeout=600)

        board_select = EntranceSoundBoardSelect(on_sound_select_callback)
        self.add_item(board_select)
async def user_config_key_autocomplete(
interaction: discord.Interaction, current: str
interaction: Interaction, current: str
) -> List[Choice[str]]:
"""
Suggest a value from the user config keys (each .conf section is a key).
"""
return [
Choice(name=key, value=key)
for key in client.user_config.sections()
@ -285,45 +151,6 @@ async def user_config_key_autocomplete(
]
async def user_config_value_autocomplete(
    interaction: discord.Interaction, current: str
) -> List[Choice[str]]:
    """
    Calls an autocomplete function depending on the entered config_key.

    Returns no suggestions if the entered option has no autocompleter
    (e.g. the option is unknown or was not entered yet).
    """
    autocompleters = {"ENTRANCE.SOUND": user_entrance_sound_autocomplete}

    # .get() instead of indexing: an unknown option must not raise a KeyError
    autocompleter = autocompleters.get(interaction.namespace.option)
    if autocompleter is None:
        return []

    print(f"config_value_autocomplete: calling {autocompleter.__name__}")

    return autocompleter(interaction, current)
def user_entrance_sound_autocomplete(
    interaction: discord.Interaction, current: str
) -> List[Choice[str]]:
    """
    Generates autocomplete options for the ENTRANCE.SOUND config key.

    Every option has the form "board/sound", where sound is the file name
    without its extension. Only options starting with the current input
    (case-insensitive) are returned, at most 25 of them.
    """
    # Discord accepts at most 25 autocomplete choices
    max_choices = 25

    prefix = current.lower()
    completions: List[Choice[str]] = []

    for board in sorted(os.listdir(SOUNDDIR)):
        board_dir = f"{SOUNDDIR}/{board}"

        # Skip stray files in the sound directory, os.listdir would raise on them
        if not os.path.isdir(board_dir):
            continue

        for filename in sorted(os.listdir(board_dir)):
            # splitext only strips the extension, so names containing dots stay intact
            soundpath = f"{board}/{os.path.splitext(filename)[0]}"
            if soundpath.lower().startswith(prefix):
                completions.append(Choice(name=soundpath, value=soundpath))

    return completions[:max_choices]
@client.tree.command(
name="userconfig",
description="User-spezifische Heidi-Einstellungen (Heidi merkt sie sich in ihrem riesigen Gehirn).",
@ -331,48 +158,54 @@ def user_entrance_sound_autocomplete(
@app_commands.rename(config_key="option")
@app_commands.describe(config_key="Die Option, welche du ändern willst.")
@app_commands.autocomplete(config_key=user_config_key_autocomplete)
@app_commands.rename(config_value="wert")
@app_commands.describe(
config_value="Der Wert, auf welche die Option gesetzt werden soll."
)
@app_commands.autocomplete(config_value=user_config_value_autocomplete)
async def user_config(
interaction: discord.Interaction, config_key: str, config_value: str
):
@enforce_channel(HEIDI_SPAM_ID)
async def user_config(interaction: Interaction, config_key: str) -> None:
"""
Set a user config value for the calling user.
"""
# Only Members can set settings
if not isinstance(interaction.user, discord.Member):
if not isinstance(interaction.user, Member):
print("User not a member")
await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!")
await interaction.response.send_message(
"Heidi sagt: Komm in die Gruppe!", ephemeral=True
)
return
member: discord.Member = interaction.user
member: Member = interaction.user
client.user_config[config_key][member.name] = config_value
client.write_user_config()
async def on_sound_select_callback(interaction, board: str, sound: str):
"""
This function is called, when an EntrySoundSoundSelect option is selected.
"""
client.user_config[config_key][member.name] = f"{board}/{sound}"
client.write_user_config()
await interaction.response.send_message(
f"Ok, ich schreibe {member.name}={board}/{sound} in mein fettes Gehirn!",
ephemeral=True,
)
# Views for different user config options are defined here
views = {"ENTRANCE.SOUND": (EntranceSoundBoardView, on_sound_select_callback)}
view, select_callback = views[config_key]
await interaction.response.send_message(
f"Ok, ich schreibe {member.name}={config_value} in mein fettes Gehirn!"
f"Aus welchem Soundboard soll dein sound sein?",
view=view(select_callback),
ephemeral=True,
)
# Commands ---------------------------------------------------------------------------------------
@client.tree.command(
name="giblinkbruder",
description="Heidi hilft mit dem Link zu deiner Lieblingsshow im Qualitätsfernsehen.",
)
async def show_link(interaction: discord.Interaction):
link_pro7 = "https://www.prosieben.de/tv/germanys-next-topmodel/livestream"
link_joyn = "https://www.joyn.de/serien/germanys-next-topmodel"
await interaction.response.send_message(
f"ProSieben: {link_pro7}\nJoyn: {link_joyn}"
)
@client.tree.command(name="heidi", description="Heidi!")
async def heidi_exclaim(interaction: discord.Interaction):
@enforce_channel(HEIDI_SPAM_ID)
async def heidi_exclaim(interaction: Interaction) -> None:
"""
Print a random Heidi quote.
"""
messages = [
"Die sind doch fast 18!",
"Heidi!",
@ -381,6 +214,10 @@ async def heidi_exclaim(interaction: discord.Interaction):
"Warum denn so schüchtern?",
"Im TV ist das legal!",
"Das Stroh ist nur fürs Shooting!",
"Jetzt sei doch mal sexy!",
"Stell dich nicht so an!",
"Models müssen da halt durch!",
"Heul doch nicht!",
]
await interaction.response.send_message(random.choice(messages))
@ -388,16 +225,27 @@ async def heidi_exclaim(interaction: discord.Interaction):
@client.tree.command(name="miesmuschel", description="Was denkt Heidi?")
@app_commands.rename(question="frage")
@app_commands.describe(question="Heidi wird es beantworten!")
async def magic_shell(interaction: discord.Interaction, question: str):
@enforce_channel(HEIDI_SPAM_ID)
async def magic_shell(interaction: Interaction, question: str) -> None:
"""
Answer a yes/no question.
"""
# Should be equal amounts of yes/no answers, to have a 50/50 chance.
choices = [
"Ja!",
"Jo.",
"Jo",
"Total!",
"Natürlich.",
"Natürlich",
"Klaro Karo",
"Offensichtlich Sherlock",
"Tom sagt Ja",
"Nein!",
"Nö.",
"Nä.",
"Niemals!",
"Nur über meine Leiche du Hurensohn!",
"In deinen Träumen.",
"Tom sagt Nein",
]
question = question.strip()
question_mark = "" if question[-1] == "?" else "?"
@ -406,58 +254,32 @@ async def magic_shell(interaction: discord.Interaction, question: str):
)
# TODO: Allow , separated varargs, need to parse manually as slash commands don't support varargs
# @todo Allow , separated varargs, need to parse manually as slash commands don't support varargs
@client.tree.command(name="wähle", description="Heidi trifft die Wahl!")
@app_commands.rename(option_a="entweder")
@app_commands.describe(option_a="Ist es vielleicht dies?")
@app_commands.rename(option_b="oder")
@app_commands.describe(option_b="Oder doch eher das?")
async def choose(interaction: discord.Interaction, option_a: str, option_b: str):
@enforce_channel(HEIDI_SPAM_ID)
async def choose(interaction: Interaction, option_a: str, option_b: str) -> None:
"""
Select an answer from two options.
"""
options = [option_a.strip(), option_b.strip()]
await interaction.response.send_message(
f"{options[0]} oder {options[1]}?\nHeidi sagt: {random.choice(options)}"
)
# TextGen ----------------------------------------------------------------------------------------
# async def quote_model_autocomplete(interaction: discord.Interaction, current: str) -> list[Choice[str]]:
# models = client.textgen_models.keys()
# return [Choice(name=model, value=model) for model in models]
# @client.tree.command(name="zitat", description="Heidi zitiert!")
# @app_commands.rename(quote_model = "style")
# @app_commands.describe(quote_model = "Woraus soll Heidi zitieren?")
# @app_commands.autocomplete(quote_model = quote_model_autocomplete)
# async def quote(interaction: discord.Interaction, quote_model: str):
# generated_quote = client.textgen_models[quote_model].generate_sentence()
# joined_quote = " ".join(generated_quote)
# await interaction.response.send_message(f"Heidi zitiert: \"{joined_quote}\"")
# @client.tree.command(name="vervollständige", description="Heidi beendet den Satz!")
# @app_commands.rename(prompt = "satzanfang")
# @app_commands.describe(prompt = "Der Satzanfang wird vervollständigt.")
# @app_commands.rename(quote_model = "style")
# @app_commands.describe(quote_model = "Woraus soll Heidi vervollständigen?")
# @app_commands.autocomplete(quote_model = quote_model_autocomplete)
# async def complete(interaction: discord.Interaction, prompt: str, quote_model: str):
# prompt = re.sub(r"[^a-zäöüß'.,]+", " ", prompt.lower()) # only keep valid chars
# generated_quote = client.textgen_models[quote_model].complete_sentence(prompt.split())
# joined_quote = " ".join(generated_quote)
# await interaction.response.send_message(f"Heidi sagt: \"{joined_quote}\"")
# Sounds -----------------------------------------------------------------------------------------
SOUNDDIR: str = "/sounds" if DOCKER else "./heidi-sounds"
# Example: https://discordpy.readthedocs.io/en/latest/interactions/api.html?highlight=autocomplete#discord.app_commands.autocomplete
async def board_autocomplete(
interaction: discord.Interaction, current: str
interaction: Interaction, current: str
) -> List[Choice[str]]:
"""
Suggest a sound board.
"""
boards: List[str] = os.listdir(SOUNDDIR)
return [
@ -468,15 +290,16 @@ async def board_autocomplete(
async def sound_autocomplete(
interaction: discord.Interaction, current: str
interaction: Interaction, current: str
) -> List[Choice[str]]:
"""
Suggest a sound from an already selected board.
"""
board: str = interaction.namespace.board
sounds: List[str] = list(
map(lambda x: x.split(".")[0], os.listdir(f"{SOUNDDIR}/{board}/"))
)
sounds: List[str] = os.listdir(f"{SOUNDDIR}/{board}/")
return [
Choice(name=sound, value=sound)
Choice(name=sound.split(".")[0], value=sound)
for sound in sounds
if sound.lower().startswith(current.lower())
]
@ -488,32 +311,86 @@ async def sound_autocomplete(
@app_commands.describe(sound="Was soll Heidi sagen?")
@app_commands.autocomplete(board=board_autocomplete)
@app_commands.autocomplete(sound=sound_autocomplete)
async def say_voiceline(interaction: discord.Interaction, board: str, sound: str):
@enforce_channel(HEIDI_SPAM_ID)
async def say_voiceline(interaction: Interaction, board: str, sound: str) -> None:
"""
Play a voiceline in the calling member's current voice channel.
"""
# Only Members can access voice channels
if not isinstance(interaction.user, discord.Member):
if not isinstance(interaction.user, Member):
print("User not a member")
await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!")
await interaction.response.send_message(
"Heidi sagt: Komm in die Gruppe!", ephemeral=True
)
return
member: discord.Member = interaction.user
member: Member = interaction.user
await play_voice_line_for_member(interaction, member, board, sound)
class InstantButton(discord.ui.Button):
    """
    Red button that plays one sound of a soundboard when pressed.
    """

    def __init__(self, label: str, board: str, sound: str):
        super().__init__(style=discord.ButtonStyle.red, label=label)

        self.board = board
        self.sound = sound

    async def callback(self, interaction: Interaction):
        """
        Play the sound for the pressing member, users that are not a Member are rejected.
        """
        presser = interaction.user

        if isinstance(presser, Member):
            await play_voice_line_for_member(
                interaction, presser, self.board, self.sound
            )
            return

        await interaction.response.send_message(
            "Heidi mag keine discord.User, nur discord.Member!", ephemeral=True
        )
class InstantButtonsView(discord.ui.View):
    """
    View that shows one InstantButton per sound of the given soundboard.

    A view can hold at most 25 items, so only the first 25 sounds
    (in sorted order) get a button.
    """

    # Maximum number of items a discord.ui.View accepts, add_item raises beyond that
    MAX_BUTTONS = 25

    def __init__(self, board: str, timeout=None):
        super().__init__(timeout=timeout)

        # Sorted, because os.listdir returns entries in arbitrary order
        sounds = sorted(os.listdir(f"{SOUNDDIR}/{board}"))

        for sound in sounds[: self.MAX_BUTTONS]:
            # The label is the file name without extension, the full name is kept for playback
            self.add_item(InstantButton(os.path.splitext(sound)[0], board, sound))
@client.tree.command(
    name="instantbuttons", description="Heidi malt Knöpfe für Sounds in den Chat."
)
@app_commands.describe(board="Welches Soundboard soll knöpfe bekommen?")
@app_commands.autocomplete(board=board_autocomplete)
@enforce_channel(HEIDI_SPAM_ID)
async def soundboard_buttons(interaction: Interaction, board: str) -> None:
    """
    Post a message with one button per sound of the selected soundboard.
    """
    # The board is free user input (autocomplete is only a suggestion).
    # Only accept direct entries of SOUNDDIR, so a value like "../x" can't
    # escape the sound directory and a typo doesn't raise in os.listdir.
    if board not in os.listdir(SOUNDDIR) or not os.path.isdir(f"{SOUNDDIR}/{board}"):
        await interaction.response.send_message(
            f'Heidi sagt: Soundboard "{board}" kanninich finden bruder',
            ephemeral=True,
        )
        return

    await interaction.response.send_message(
        f"Soundboard: {board.capitalize()}", view=InstantButtonsView(board)
    )
# Contextmenu ------------------------------------------------------------------------------------
# Callable on members
@client.tree.context_menu(name="beleidigen")
async def insult(
interaction: discord.Interaction, member: discord.Member
): # with message: discord.Message this can be called on a message
interaction: Interaction, member: Member
) -> None: # with message: discord.Message this can be called on a message
"""
Send an insult to a member via direct message.
"""
if not member.dm_channel:
await member.create_dm()
if not member.dm_channel:
print("Error creating DMChannel!")
await interaction.response.send_message("Heidi sagt: Gib mal DM Nummer süße*r!")
await interaction.response.send_message(
"Heidi sagt: Gib mal DM Nummer süße*r!", ephemeral=True
)
return
insults = [
@ -534,67 +411,10 @@ async def insult(
await member.dm_channel.send(random.choice(insults))
await interaction.response.send_message(
"Anzeige ist raus!"
"Anzeige ist raus!", ephemeral=True
) # with ephemeral = True only the caller can see the answer
# Helpers ----------------------------------------------------------------------------------------
async def play_voice_line(
    interaction: Union[discord.Interaction, None],
    voice_channel: discord.VoiceChannel,
    board: str,
    sound: str,
):
    """
    Play the sound SOUNDDIR/board/sound.mkv in the given voice channel.

    If an interaction is passed, it is answered with the played sound or
    with an error message if the sound file can't be opened.
    Heidi joins the channel, plays the sound and disconnects afterwards.
    """
    soundfile = f"{SOUNDDIR}/{board}/{sound}.mkv"

    # Check that the file can be opened, the handle is closed again right away
    try:
        with open(soundfile, "rb"):
            pass
    except OSError:
        print("Error: Invalid soundfile!")
        if interaction is not None:
            await interaction.response.send_message(
                f'Heidi sagt: "{board}/{sound}" kanninich finden bruder'
            )
        return

    if interaction is not None:
        await interaction.response.send_message(f'Heidi sagt: "{board}/{sound}"')

    audio_source = discord.FFmpegPCMAudio(soundfile)  # only works from docker
    voice_client = await voice_channel.connect()

    try:
        voice_client.play(audio_source)

        # Poll until the playback is finished
        while voice_client.is_playing():
            await asyncio.sleep(1)
    finally:
        # Always leave the channel, even if the playback raised
        await voice_client.disconnect()
async def play_voice_line_for_member(
    interaction: Union[discord.Interaction, None],
    member: discord.Member,
    board: str,
    sound: str,
):
    """
    Play a sound in the voice channel the member is currently connected to.

    If the member is not in a discord.VoiceChannel nothing is played, and the
    interaction (if there is one) is answered with a hint to join a channel.
    """
    # Member needs to be in voice channel to hear audio (Heidi needs to know the channel to join)
    # Identity checks (is None) instead of ==, so no custom __eq__ is involved
    if (
        member is None
        or member.voice is None
        or member.voice.channel is None
        or not isinstance(member.voice.channel, discord.VoiceChannel)
    ):
        print("User not in (valid) voice channel!")
        if interaction is not None:
            await interaction.response.send_message("Heidi sagt: Komm in den Channel!")
        return

    voice_channel: discord.VoiceChannel = member.voice.channel

    await play_voice_line(interaction, voice_channel, board, sound)
# ------------------------------------------------------------------------------------------------

18
flake.lock generated
View File

@ -6,11 +6,11 @@
"systems": "systems"
},
"locked": {
"lastModified": 1700815693,
"narHash": "sha256-JtKZEQUzosrCwDsLgm+g6aqbP1aseUl1334OShEAS3s=",
"lastModified": 1701787589,
"narHash": "sha256-ce+oQR4Zq9VOsLoh9bZT8Ip9PaMLcjjBUHVPzW5d7Cw=",
"owner": "numtide",
"repo": "devshell",
"rev": "7ad1c417c87e98e56dcef7ecd0e0a2f2e5669d51",
"rev": "44ddedcbcfc2d52a76b64fb6122f209881bd3e1e",
"type": "github"
},
"original": {
@ -24,11 +24,11 @@
"systems": "systems_2"
},
"locked": {
"lastModified": 1694529238,
"narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=",
"lastModified": 1701680307,
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "ff7b65b44d01cf9ba6a71320833626af21126384",
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
"type": "github"
},
"original": {
@ -55,11 +55,11 @@
},
"nixpkgs_2": {
"locked": {
"lastModified": 1700856099,
"narHash": "sha256-RnEA7iJ36Ay9jI0WwP+/y4zjEhmeN6Cjs9VOFBH7eVQ=",
"lastModified": 1701693815,
"narHash": "sha256-7BkrXykVWfkn6+c1EhFA3ko4MLi3gVG0p9G96PNnKTM=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "0bd59c54ef06bc34eca01e37d689f5e46b3fe2f1",
"rev": "09ec6a0881e1a36c29d67497693a67a16f4da573",
"type": "github"
},
"original": {

View File

@ -14,27 +14,6 @@
overlays = [ devshell.overlays.default ];
};
# TODO: Originally it was nixpkgs.fetchurl but that didn't work, pkgs.fetchurl did...
# Determine the difference between nixpkgs and pkgs
# Taken from: https://github.com/gbtb/nix-stable-diffusion/blob/master/flake.nix
# Overlay: https://nixos.wiki/wiki/Overlays
# FetchURL: https://ryantm.github.io/nixpkgs/builders/fetchers/
torch-rocm = pkgs.hiPrio (pkgs.python310Packages.torch-bin.overrideAttrs (old: {
src = pkgs.fetchurl {
name = "torch-1.12.1+rocm5.1.1-cp310-cp310-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/rocm5.1.1/torch-1.12.1%2Brocm5.1.1-cp310-cp310-linux_x86_64.whl";
hash = "sha256-kNShDx88BZjRQhWgnsaJAT8hXnStVMU1ugPNMEJcgnA=";
};
}));
torchvision-rocm = pkgs.hiPrio (pkgs.python310Packages.torchvision-bin.overrideAttrs (old: {
src = pkgs.fetchurl {
name = "torchvision-0.13.1+rocm5.1.1-cp310-cp310-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/rocm5.1.1/torchvision-0.13.1%2Brocm5.1.1-cp310-cp310-linux_x86_64.whl";
hash = "sha256-mYk4+XNXU6rjpgWfKUDq+5fH/HNPQ5wkEtAgJUDN/Jg=";
};
}));
myPython = pkgs.python311.withPackages (p: with p; [
# Basic
rich

BIN
heidi-sounds/basic/Suiii.mkv (Stored with Git LFS) Normal file

Binary file not shown.

BIN
heidi-sounds/drache/Hagebuddne.mkv (Stored with Git LFS) Normal file

Binary file not shown.

BIN
heidi-sounds/henri/Ich bin der Pablo.mkv (Stored with Git LFS) Normal file

Binary file not shown.

BIN
heidi-sounds/henri/Kann ich behilflich sein.mkv (Stored with Git LFS) Normal file

Binary file not shown.

BIN
heidi-sounds/henri/Yakari.mkv (Stored with Git LFS) Normal file

Binary file not shown.

BIN
heidi-sounds/tit/Ab in Knast.mkv (Stored with Git LFS) Normal file

Binary file not shown.

BIN
heidi-sounds/tit/Bitconnect.mkv (Stored with Git LFS) Normal file

Binary file not shown.

BIN
heidi-sounds/vinz/VINZENT.mkv (Stored with Git LFS) Normal file

Binary file not shown.

145
heidi_client.py Normal file
View File

@ -0,0 +1,145 @@
import configparser
from discord import app_commands, Message, VoiceState
from heidi_constants import *
from heidi_helpers import *
class HeidiClient(discord.Client):
    """
    Discord client for Heidi.

    Keeps the application command tree, the persistent per-user configuration
    (a configparser file at CONFIGPATH/USERCONFIGNAME) and two trigger tables that map
    predicates to automatic actions for messages and voice state changes.
    """

    def __init__(self, *, intents: discord.Intents):
        super().__init__(status="Nur eine kann GNTM werden!", intents=intents)

        # Separate object that keeps all application command state
        self.tree = app_commands.CommandTree(self)

        # Handle persistent user configuration
        self.user_config = configparser.ConfigParser()
        if not os.path.exists(f"{CONFIGPATH}/{USERCONFIGNAME}"):
            # Create an empty config file and close the handle right away (no leaked descriptor)
            with open(f"{CONFIGPATH}/{USERCONFIGNAME}", "x"):
                pass
        self.user_config.read(f"{CONFIGPATH}/{USERCONFIGNAME}")
        self.update_to_default_user_config()
        self.print_user_config()

        # Automatic actions on all messages.
        # on_message_triggers is a dict that maps a predicate to an action:
        # the predicate receives the message as argument,
        # if the predicate is true the action is performed.
        self.on_message_triggers = {
            lambda m: self._nick_contains(m, "jeremy"): self._autoreact_to_jeremy,
            lambda m: self._nick_contains(
                m, "kardashian", "jenner"
            ): self._autoreact_to_kardashian,
        }

        # Automatic actions on voice state changes.
        # on_voice_state_triggers is a dict that maps a predicate to an action:
        # the predicate receives the member, before- and after-state as arguments,
        # if the predicate is true the action is performed.
        self.on_voice_state_triggers = {
            lambda m, b, a: b.channel != a.channel
            and a.channel is not None
            and isinstance(a.channel, VoiceChannel): self._play_entrance_sound,
        }

    # Synchronize commands to guilds
    async def setup_hook(self):
        """
        Copy the global application commands to both known guilds and sync them.
        """
        self.tree.copy_global_to(guild=LINUS_GUILD)
        await self.tree.sync(guild=LINUS_GUILD)
        self.tree.copy_global_to(guild=TEST_GUILD)
        await self.tree.sync(guild=TEST_GUILD)

    def update_to_default_user_config(self) -> None:
        """
        Adds config keys to the config, if they don't exist yet.
        This writes the user config file.
        """
        user_config_sections = ["ENTRANCE.SOUND"]

        for section in user_config_sections:
            if section not in self.user_config:
                print(f"Adding section {section} to {CONFIGPATH}/{USERCONFIGNAME}")
                self.user_config[section] = dict()

        self.write_user_config()

    def print_user_config(self) -> None:
        """
        Print the current user config from memory.
        This does not read the user config file.
        """
        print("Heidi User Config:\n")

        for section in self.user_config.sections():
            print(f"[{section}]")
            for key in self.user_config[section]:
                print(f"{key}={self.user_config[section][key]}")
            print("")

    def write_user_config(self) -> None:
        """
        Write the current configuration to disk.
        Nothing is written if the config file does not exist.
        """
        if not os.path.exists(f"{CONFIGPATH}/{USERCONFIGNAME}"):
            print(f"Error: {CONFIGPATH}/{USERCONFIGNAME} doesn't exist!")
            return

        print(f"Writing {CONFIGPATH}/{USERCONFIGNAME}")

        with open(f"{CONFIGPATH}/{USERCONFIGNAME}", "w") as file:
            self.user_config.write(file)

    # Automatic Actions ------------------------------------------------------------------------------

    @staticmethod
    def _nick_contains(message: Message, *needles: str) -> bool:
        """
        Return True if the author's guild nickname contains any of the (lowercase) needles.

        The nickname is None for members without one and plain users have no nick
        attribute at all, so both cases are treated as "no match" instead of raising.
        """
        nick = getattr(message.author, "nick", None)
        if not nick:
            return False

        nick = nick.lower()
        return any(needle in nick for needle in needles)

    @staticmethod
    async def _autoreact_to_jeremy(message: Message) -> None:
        """
        🧀 Jeremy.
        This function is set in on_message_triggers and triggered by the on_message event.
        """
        await message.add_reaction("🧀")

    @staticmethod
    async def _autoreact_to_kardashian(message: Message) -> None:
        """
        💄 Kardashian.
        This function is set in on_message_triggers and triggered by the on_message event.
        """
        await message.add_reaction("💄")

    async def _play_entrance_sound(
        self,
        member: Member,
        before: VoiceState,
        after: VoiceState,
    ) -> None:
        """
        Play a sound when a member joins a voice channel (and another member is present).
        This function is set in on_voice_state_triggers and triggered by the on_voice_state_update event.

        The sound is looked up in the "ENTRANCE.SOUND" config section by member name,
        the value has the form "<board>/<sound>".
        """
        # Don't play anything when no other users are present
        if (
            member is not None
            and member.voice is not None
            and member.voice.channel is not None
            and len(member.voice.channel.members) <= 1
        ):
            print("Not playing entrance sound, as no other members are present")
            return

        soundpath: Union[str, None] = self.user_config["ENTRANCE.SOUND"].get(
            member.name, None
        )

        if soundpath is None:
            print(f"User {member.name} has not set an entrance sound")
            return

        # A hand-edited config value without "<board>/<sound>" shape must not crash the event handler
        board, separator, sound = soundpath.partition("/")
        if not separator or not board or not sound:
            print(f"Error: Invalid entrance sound {soundpath} for user {member.name}")
            return

        # Wait a bit to not have simultaneous joins
        await asyncio.sleep(1)

        await play_voice_line_for_member(None, member, board, sound)

29
heidi_constants.py Normal file
View File

@ -0,0 +1,29 @@
import os
import discord
from dotenv import load_dotenv
# This is run when this file is imported:
# load_dotenv() is called before any os.getenv() below, so the order of these statements matters
load_dotenv()
print("Debug: Importing heidi_constants.py")
# ================================================================================================ #
# ================================================================================================ #
# NOTE: Always set this correctly:
# DOCKER is only True if the DOCKER environment variable is exactly the string "True"
DOCKER = os.getenv("DOCKER") == "True"
# ================================================================================================ #
# ================================================================================================ #
# Constants
# Directory that contains the user config file
CONFIGPATH = "/config" if DOCKER else "."
# File name of the user config inside CONFIGPATH
USERCONFIGNAME = "Heidi_User.conf"
# Root directory of the sound files
SOUNDDIR: str = "/sounds" if DOCKER else "./heidi-sounds"
# IDs of the servers Heidi is used on
LINUS_GUILD = discord.Object(id=431154792308408340)
TEST_GUILD = discord.Object(id=821511861178204161)
# Channel IDs
# presumably the heidi_spam channel that commands are restricted to - verify against callers
HEIDI_SPAM_ID = 822223476101742682

105
heidi_helpers.py Normal file
View File

@ -0,0 +1,105 @@
import asyncio
import functools
from typing import Union
import discord
from discord import Interaction, VoiceChannel, Member
from heidi_constants import *
print("Debug: Importing heidi_helpers.py")
# Checks -----------------------------------------------------------------------------------------
# 1. @enforce_channel(ID) is evaluated first and returns decorate, which keeps channel_id in its closure
# 2. The decorated function is then passed to decorate(function), which returns the wrapped coroutine
def enforce_channel(channel_id):
    """
    Only run a function if called from the correct channel.

    Decorator factory for coroutine functions whose first positional argument is the
    interaction. If the interaction's channel_id differs from channel_id, an ephemeral
    response is sent and the decorated function is not called (the wrapper returns None).
    Otherwise the decorated function is awaited and its result is returned.
    """

    def decorate(function):
        @functools.wraps(function)
        async def wrapped(*args, **kwargs):
            """
            Sends an interaction response if the interaction is not triggered from the heidi_spam channel.
            """
            # The interaction is expected as the first positional argument
            interaction = args[0]

            # Do not call the decorated function if the channel_id doesn't match
            if interaction.channel_id != channel_id:
                await interaction.response.send_message("Heidi sagt: Geh in heidi_spam du dulli", ephemeral=True)
                return None

            # Hand the result back to the caller instead of silently dropping it
            return await function(*args, **kwargs)

        return wrapped

    return decorate
# Sounds -----------------------------------------------------------------------------------------
# @todo Normalize volume when playing
async def play_voice_line(
    interaction: Union[Interaction, None],
    voice_channel: VoiceChannel,
    board: str,
    sound: str,
) -> None:
    """
    Play a voice line in the specified channel.

    The sound file is SOUNDDIR/<board>/<sound>. If it can't be opened, an error is
    printed (and sent as ephemeral response, if an interaction is given) and nothing
    is played. Heidi connects to the channel, waits until playback has finished and
    disconnects again, even if playback fails.
    """
    soundfile = f"{SOUNDDIR}/{board}/{sound}"

    # Probe that the file can be opened, the context manager closes the handle again
    try:
        with open(soundfile):
            pass
    except IOError:
        print(f"Error: Invalid soundfile {soundfile}!")
        if interaction is not None:
            await interaction.response.send_message(
                f'Heidi sagt: "{board}/{sound}" kanninich finden bruder',
                ephemeral=True
            )
        return

    if interaction is not None:
        await interaction.response.send_message(f'Heidi sagt: "{board}/{sound}"', ephemeral=True)

    audio_source = discord.FFmpegPCMAudio(soundfile)  # only works from docker
    voice_client = await voice_channel.connect()

    # Always leave the channel again, otherwise a failed playback keeps Heidi connected
    try:
        voice_client.play(audio_source)

        while voice_client.is_playing():
            await asyncio.sleep(1)
    finally:
        await voice_client.disconnect()
async def play_voice_line_for_member(
    interaction: Union[Interaction, None],
    member: Member,
    board: str,
    sound: str,
) -> None:
    """
    Play a voice line in the channel the member is currently connected to.

    If the member is not in a voice channel, a message is printed (and sent as
    ephemeral response, if an interaction is given) and nothing is played.
    """
    # Heidi needs to know which channel to join, so resolve the member's channel step by step
    voice_state = None if member is None else member.voice
    channel = None if voice_state is None else voice_state.channel

    # isinstance is False for None, so this covers the "not connected" cases as well
    if not isinstance(channel, VoiceChannel):
        print("User not in (valid) voice channel!")
        if interaction is not None:
            await interaction.response.send_message("Heidi sagt: Komm in den Channel!", ephemeral=True)
        return

    await play_voice_line(interaction, channel, board, sound)

View File

@ -1,9 +0,0 @@
#!/bin/sh
# Redeploy heidibot: update the checkout, pull the image and replace the running container.

# Abort if the checkout is missing, so nothing below runs in the wrong directory
cd /home/christoph/HeidiBot || exit 1
git pull
docker pull registry.gitlab.com/churl/heidibot
# Remove the old container so the name is free for the new one
docker container rm -f heidibot
docker run -d --env-file /home/christoph/HeidiBot/.env --mount src=/home/christoph/HeidiBot/voicelines,target=/sounds,type=bind --name heidibot registry.gitlab.com/churl/heidibot
docker image prune -f

View File

@ -1,33 +0,0 @@
#!/usr/bin/env python3
import requests
import re
from bs4 import BeautifulSoup
class Models:
    """
    Scrapes the candidates from the GNTM models page.

    The page is requested once when the object is created. Candidates are the <a>
    elements with class "candidate-in" / "candidate-out", keyed by their lowercased title.
    """

    def __init__(self):
        url_girls = "https://www.prosieben.de/tv/germanys-next-topmodel/models"

        # Without a timeout a stalled server would block the constructor forever
        html_girls = requests.get(url_girls, timeout=10)
        soup_girls = BeautifulSoup(html_girls.text, "html.parser")

        # find_all is the current name of the deprecated findAll alias
        girls_in = soup_girls.find_all("a", class_="candidate-in")
        girls_out = soup_girls.find_all("a", class_="candidate-out")

        self.girls_in = {girl.get("title").lower(): girl for girl in girls_in}
        self.girls_out = {girl.get("title").lower(): girl for girl in girls_out}

        self.girls = {**self.girls_in, **self.girls_out}

    def get_in_names(self):
        """
        Return the lowercased names of the "candidate-in" entries.
        """
        return self.girls_in.keys()

    def get_out_names(self):
        """
        Return the lowercased names of the "candidate-out" entries.
        """
        return self.girls_out.keys()

    def get_image(self, name):
        """
        Return the image url of a candidate, taken from the style attribute of the teaser figure.

        Raises KeyError if the name is unknown.
        """
        style = self.girls[name.lower()].find("figure", class_="teaser-img").get("style")
        url = re.search(r"url\(.*\);", style).group()

        # Strip the leading "url(" and the last 9 characters
        # (presumably the original resolution plus ");" - verify against the page markup)
        return url[4:-9] + "562x996"  # increase resolution

View File

@ -3,12 +3,3 @@ rich
discord.py # maintained again
pynacl # voice support
python-dotenv # discord guild secrets
# Webscraping
# requests
# beautifulsoup4
# Textgeneration
# torch
# numpy
# nltk

View File

@ -1,44 +0,0 @@
#!/usr/bin/env python3
from rich.traceback import install
install()
from abc import ABC, abstractmethod
# In Python it is generally not needed to use abstract classes, but I wanted to do it safely
class textgen(ABC):
    """
    Interface shared by all text generators.

    Every method is abstract, so the class itself can't be instantiated and a
    subclass has to override all of them.
    """

    @abstractmethod
    def init(self, filename):
        """
        Set up the generator.

        filename - The file (same directory as textgen.py) that contains the training text
        """
        raise NotImplementedError("Can't use abstract class")

    @abstractmethod
    def load(self):
        """
        Restore an already trained model from a precomputed file
        """
        raise NotImplementedError("Can't use abstract class")

    @abstractmethod
    def train(self):
        """
        Build the model from the training text, using the parameters set up in init()
        """
        raise NotImplementedError("Can't use abstract class")

    @abstractmethod
    def generate_sentence(self):
        """
        Produce words/characters until a . has been generated
        """
        raise NotImplementedError("Can't use abstract class")

    @abstractmethod
    def complete_sentence(self, prefix):
        """
        Produce the remainder of a sentence that starts with the given prefix
        """
        raise NotImplementedError("Can't use abstract class")

View File

@ -1,303 +0,0 @@
#!/usr/bin/env python3
import re, random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from textgen import textgen
from torch import nn, optim
from rich.traceback import install
install()
# Model =======================================================================================
# https://towardsdatascience.com/text-generation-with-bi-lstm-in-pytorch-5fda6e7cc22c
# Embedding -> Bi-LSTM -> LSTM -> Linear
class Model(nn.ModuleList):
    """
    Next-character classifier: Embedding -> Bi-LSTM -> LSTM -> Linear.

    args is a dict with the keys "batch_size", "hidden_dim", "vocab_size" and "window".
    forward() takes a tensor of token indices and returns one row of vocab_size
    scores per sample.
    """

    def __init__(self, args, device):
        super().__init__()

        self.device = device
        self.batch_size = args["batch_size"]
        self.hidden_dim = args["hidden_dim"]
        self.input_size = args["vocab_size"]
        self.num_classes = args["vocab_size"]
        self.sequence_len = args["window"]

        # Dropout (defined here, but not applied in forward())
        self.dropout = nn.Dropout(0.25)  # Don't need to set device for the layers as we transfer the whole model later

        # Embedding layer
        self.embedding = nn.Embedding(self.input_size, self.hidden_dim, padding_idx=0)

        # Bi-LSTM
        # Forward and backward
        self.lstm_cell_forward = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
        self.lstm_cell_backward = nn.LSTMCell(self.hidden_dim, self.hidden_dim)

        # LSTM layer
        self.lstm_cell = nn.LSTMCell(self.hidden_dim * 2, self.hidden_dim * 2)

        # Linear layer
        self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)

    def forward(self, x):
        """
        Compute the scores for the next token.

        x - index tensor, indexed as [batch x sequence]
        Returns a tensor [batch x vocab_size].
        """
        batch = x.size(0)

        # Bi-LSTM
        # hs = [batch_size x hidden_size]
        # cs = [batch_size x hidden_size]
        # Need to specify device here as these states are not part of the model directly
        hs_forward = torch.zeros(batch, self.hidden_dim, device=self.device)
        cs_forward = torch.zeros(batch, self.hidden_dim, device=self.device)
        hs_backward = torch.zeros(batch, self.hidden_dim, device=self.device)
        cs_backward = torch.zeros(batch, self.hidden_dim, device=self.device)

        # LSTM
        # hs = [batch_size x (hidden_size * 2)]
        # cs = [batch_size x (hidden_size * 2)]
        hs_lstm = torch.zeros(batch, self.hidden_dim * 2, device=self.device)
        cs_lstm = torch.zeros(batch, self.hidden_dim * 2, device=self.device)

        # Initial states are drawn randomly on every call
        # NOTE(review): this also happens in eval mode, so predictions are not deterministic - confirm this is intended
        for state in (hs_forward, cs_forward, hs_backward, cs_backward, hs_lstm, cs_lstm):
            torch.nn.init.kaiming_normal_(state)

        # From idx to embedding: [batch x sequence x hidden]
        out = self.embedding(x)

        # Prepare the shape for LSTM Cells: [sequence x batch x hidden]
        # This has to be a transpose: view() only reinterprets the memory and would mix
        # the tokens of different samples as soon as the batch has more than one sample
        out = out.transpose(0, 1)

        # Iterate over the steps that are actually present in the input
        steps = out.size(0)

        forward = []
        backward = []

        # Unfolding Bi-LSTM
        # Forward
        for i in range(steps):
            hs_forward, cs_forward = self.lstm_cell_forward(out[i], (hs_forward, cs_forward))
            forward.append(hs_forward)

        # Backward
        for i in reversed(range(steps)):
            hs_backward, cs_backward = self.lstm_cell_backward(out[i], (hs_backward, cs_backward))
            backward.append(hs_backward)

        # LSTM
        # NOTE(review): backward is in reversed time order, so fwd of step t is paired with bwd of step (steps - 1 - t)
        for fwd, bwd in zip(forward, backward):
            input_tensor = torch.cat((fwd, bwd), 1)
            hs_lstm, cs_lstm = self.lstm_cell(input_tensor, (hs_lstm, cs_lstm))

        # Last hidden state is passed through a linear layer
        out = self.linear(hs_lstm)

        return out
# =============================================================================================
class LSTMTextGenerator(textgen):
    """
    Character based text generator that uses the LSTM Model.

    A window of windowsize characters is slid over the training text, the model
    learns to predict the character that directly follows the window.
    """

    def __init__(self, windowsize):
        self.windowsize = windowsize  # We slide a window over the character sequence and look at the next letter,
                                      # similar to the Markov chain order

    def init(self, filename):
        """
        Read ./textfiles/<filename>.txt, select the device, create the model and the training sequences.
        """
        self.filename = filename

        # Use this to generate one hot vector and filter characters
        self.letters = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
                        "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "ä", "ö", "ü", ".", " "]

        with open(f"./textfiles/{filename}.txt", "r") as file:
            lines = [line.lower() for line in file.readlines()]  # lowercase list

        text = " ".join(lines)  # single string
        self.charbase = [char for char in text if char in self.letters]  # list of characters

        # Select device
        if torch.cuda.is_available():
            dev = "cuda:0"
            print("Selected GPU for LSTM")
        else:
            dev = "cpu"
            print("Selected CPU for LSTM")

        self.device = torch.device(dev)

        # Init model
        self.args = {
            "window": self.windowsize,
            "hidden_dim": 128,
            "vocab_size": len(self.letters),
            "batch_size": 128,
            "learning_rate": 0.0005,
            "num_epochs": 100
        }

        self.model = Model(self.args, self.device)
        self.model.to(self.device)  # All model layers need to use the correct tensors (cpu/gpu)

        # Needed for both training and generation
        self.__generate_char_sequences()

    # Helpers

    def __char_to_idx(self, char):
        """Return the index of char in self.letters (ValueError for unknown characters)."""
        return self.letters.index(char)

    def __idx_to_char(self, idx):
        """Return the character that belongs to an index of self.letters."""
        return self.letters[idx]

    def __generate_char_sequences(self):
        """
        Create self.prefixes (one window of indices per row) and self.suffixes
        (the index of the character that directly follows each window).
        """
        window = self.windowsize

        print("Generating LSTM char sequences...")

        # Translate every character only once
        indices = [self.__char_to_idx(char) for char in self.charbase]

        # Example
        # [[21, 20, 15],
        #  [12, 12, 14]]
        prefixes = []

        # Example
        # [1, 4]
        suffixes = []

        # The window covers i .. i + window - 1, so the character to predict is at i + window
        # (not i + window + 1, generation appends the prediction directly behind the window)
        for i in range(len(indices) - window):
            prefixes.append(indices[i:i + window])
            suffixes.append(indices[i + window])

        # Enter numpy territory NOW
        self.prefixes = np.array(prefixes)
        self.suffixes = np.array(suffixes)

        print(f"Prefixes shape: {self.prefixes.shape}")
        print(f"Suffixes shape: {self.suffixes.shape}")

        print("Completed.")

    # Interface

    # TODO: Also save/load generated prefixes
    def load(self):
        """
        Load the model weights from weights/<filename>_lstm_model.pt.
        """
        print(f"Loading \"{self.filename}\" LSTM model with {len(self.charbase)} characters from file.")

        self.model.load_state_dict(torch.load(f"weights/{self.filename}_lstm_model.pt"))

    def train(self):
        """
        Train the model, save the weights to weights/<filename>_lstm_model.pt and plot the losses.

        Raises ValueError if the training text is too short to form a single window.
        """
        print(f"Training \"{self.filename}\" LSTM model with {len(self.charbase)} characters.")

        if len(self.prefixes) == 0:
            raise ValueError("Training text is too short for the selected window size")

        batch_size = self.args["batch_size"]

        # Optimizer initialization, RMSprop for RNN
        optimizer = optim.RMSprop(self.model.parameters(), lr=self.args["learning_rate"])

        # Set model in training mode
        self.model.train()

        losses = []

        # Training phase
        for epoch in range(self.args["num_epochs"]):

            # Mini batches: slicing never raises, the last batch is simply shorter,
            # so the remainder of the data is used as well
            for start in range(0, len(self.prefixes), batch_size):
                x_batch = self.prefixes[start:start + batch_size]
                y_batch = self.suffixes[start:start + batch_size]

                # Convert numpy array into torch tensors
                x = torch.from_numpy(x_batch).type(torch.long).to(self.device)
                y = torch.from_numpy(y_batch).type(torch.long).to(self.device)

                # Feed the model
                y_pred = self.model(x)

                # Loss calculation (y is already one-dimensional)
                loss = F.cross_entropy(y_pred, y)
                losses.append(loss.item())

                # Clean gradients
                optimizer.zero_grad()

                # Calculate gradients
                loss.backward()

                # Update parameters
                optimizer.step()

            print("Epoch: %d , loss: %.5f " % (epoch, loss.item()))

        torch.save(self.model.state_dict(), f"weights/{self.filename}_lstm_model.pt")
        print(f"Saved \"{self.filename}\" LSTM model to file")

        plt.plot(np.arange(0, len(losses)), losses)
        plt.title(self.filename)
        plt.show()

    def generate_sentence(self):
        """
        Complete a randomly selected window of the training text.
        """
        # The upper bound of randint is exclusive, so every prefix can be selected
        start = np.random.randint(len(self.prefixes))

        # Convert back to string to match complete_sentence
        pattern = "".join([self.__idx_to_char(char) for char in self.prefixes[start]])  # random sequence from the training text

        return self.complete_sentence(pattern)

    def complete_sentence(self, prefix):
        """
        Append predicted characters to prefix until a . is generated (at most 500 characters).

        prefix may only contain characters from self.letters. If it is longer than the
        window, only its last windowsize characters are fed to the model.
        Returns the prefix together with the generated characters as string.
        """
        print("Prefix:", prefix)

        # In full_prediction we will save the complete prediction
        full_prediction = np.array([self.__char_to_idx(char) for char in prefix])

        # The model only looks at one window
        pattern = full_prediction[-self.windowsize:]

        # Set the model in evaluation mode
        self.model.eval()

        print("Generating sentence...")

        # No gradients are needed for generation
        with torch.no_grad():
            # Predict the next characters one by one, append chars to the starting pattern until . is reached, max 500 iterations
            for _ in range(500):
                # The numpy pattern is transformed into a tensor and reshaped to a batch with one sample
                x = torch.from_numpy(pattern).type(torch.long).to(self.device).view(1, -1)

                # Make a prediction given the pattern
                prediction = self.model(x)

                # Take the idx with the highest score (softmax would not change the argmax)
                arg_max = int(torch.argmax(prediction, dim=1).item())

                # The window is moved 1 character to the right: drop the first, append the predicted character
                pattern = np.append(pattern[1:], arg_max)

                # The full prediction is saved
                full_prediction = np.append(full_prediction, arg_max)

                # Stop on . character
                if self.__idx_to_char(arg_max) == ".":
                    break

        full_prediction = "".join([self.__idx_to_char(value) for value in full_prediction])

        print("Generated:", full_prediction)

        return full_prediction

View File

@ -1,82 +0,0 @@
#!/usr/bin/env python3
import re
import random
from textgen import textgen
from rich.traceback import install
install()
# NOTE: This is word based, not character based
# TODO: Serialize and save/load model (don't train on the server)
# TODO: Maybe extract sentence beginnings and use them as starters?
class MarkovTextGenerator(textgen):
    """
    Word based Markov chain text generator.

    word_table maps a prefix (tuple of order words) to the list of words that follow
    it in the training text. Words occur multiple times in the list, if they follow
    the prefix multiple times, so common suffixes are selected more often.
    """

    # The greater the order (prefix length), the lesser the variation in generation, but the better the sentences (generally).
    # If the prefix length is high there are less options to choose from, so the sentences are very close to the training text.
    def __init__(self, order):  # Set order here for better interface (only needed for markov model)
        self.order = order

    def init(self, filename):  # Filename is needed for every type of model so it's part of the interface
        """
        Read ./textfiles/<filename>.txt as lowercase list of words and reset the word table.
        """
        with open(f"./textfiles/{filename}.txt", "r") as file:
            # Remove all characters except a-zäöüß'.,
            self.wordbase = re.sub(r"[^a-zäöüß'.,]+", " ", file.read().lower()).split()

        self.word_table = dict()

    def load(self):
        """
        Only prints a message, nothing is read from disk.
        """
        print(f"Loaded Markov chain of order {self.order} with {len(self.wordbase)} words from file.")

    def train(self):
        """
        Fill word_table with the suffixes of every prefix in the word base.
        """
        print(f"Training Markov chain of order {self.order} with {len(self.wordbase)} words.")

        # The last valid suffix index is len - 1, so the last prefix starts at len - order - 1
        for i in range(len(self.wordbase) - self.order):  # Look at every word in range
            prefix = tuple(self.wordbase[i:i + self.order])  # Look at the next self.order words from current position
            suffix = self.wordbase[i + self.order]  # The next word is the suffix

            # Duplicates are kept on purpose: if the suffixes are in the list multiple times they are more common
            self.word_table.setdefault(prefix, []).append(suffix)

        print(f"Generated suffixes for {len(self.word_table)} prefixes.")

    def generate_suffix_for_prefix(self, prefix: tuple):
        """
        Return a random word that can follow the prefix.

        If the prefix is unknown, the suffixes of the first table entry that ends with
        the same word are used. If there is no such entry either, the suffix is taken
        from a random table entry.
        """
        if len(prefix) > self.order:  # In this case we look at the last self.order elements of prefix
            prefix = prefix[-self.order:]

        if prefix in self.word_table:
            return random.choice(self.word_table[prefix])

        # In this case we need to choose a possible suffix from the last word in the prefix (if prefix too short for example)
        print(f"Prefix {prefix} not in table")

        if prefix:
            for key, suffixes in self.word_table.items():
                if key[-1] == prefix[-1]:
                    return random.choice(suffixes)

        # Dead end (the word only occurs at the end of the text for example): continue with any known prefix
        return random.choice(random.choice(list(self.word_table.values())))

    def _extend_until_period(self, output):
        """Append generated words to output until the last word contains a period."""
        while "." not in output[-1]:
            output.append(self.generate_suffix_for_prefix(tuple(output[-self.order:])))

        return output

    def generate_sentence(self):
        """
        Generate a sentence (list of words) that starts with a random prefix of the table.
        """
        fword = random.choice(list(self.word_table.keys()))

        return self._extend_until_period([*fword])

    def complete_sentence(self, prefix):
        """
        Generate the rest of a sentence (list of words) for the given beginning.
        An empty beginning generates a completely new sentence.
        """
        output = [*prefix]

        if not output:
            return self.generate_sentence()

        return self._extend_until_period(output)