From b1d05cf7c8f4c3386cd5ade28859d880695efa38 Mon Sep 17 00:00:00 2001 From: Christoph Urlacher Date: Tue, 20 Feb 2024 13:45:16 +0100 Subject: [PATCH] Improve typing --- database_utils.py | 43 ++++++++++++++++--------------- formula10.py | 54 +++++++++++++++++++-------------------- model.py | 61 ++++++++++++++++++++++---------------------- template_model.py | 36 ++++++++++++-------------- templates/base.jinja | 2 +- 5 files changed, 97 insertions(+), 99 deletions(-) diff --git a/database_utils.py b/database_utils.py index 1fba9ff..c800cd4 100644 --- a/database_utils.py +++ b/database_utils.py @@ -1,10 +1,11 @@ import csv import os.path - -from model import * +from typing import List, Any +from flask_sqlalchemy import SQLAlchemy +from model import Team, Driver, Race, User, RaceResult, RaceGuess, TeamWinners, PodiumDrivers, SeasonGuess -def load_csv(filename): +def load_csv(filename: str) -> List[List[str]]: if not os.path.exists(filename): print(f"Could not load data from file {filename}, as it doesn't exist!") return [] @@ -15,7 +16,7 @@ def load_csv(filename): return list(reader) -def write_csv(filename, objects): +def write_csv(filename: str, objects: List[Any]): if len(objects) == 0: print(f"Could not write objects to file {filename}, as no objects were given!") return @@ -28,15 +29,15 @@ def write_csv(filename, objects): # Reload static database data, this has to be called from the app context -def reload_static_data(db): +def reload_static_data(db: SQLAlchemy): print("Initializing Database with Static Values...") # Create it (if it doesn't exist!) db.create_all() # Clear static data - Team.query.delete() - Driver.query.delete() - Race.query.delete() + db.session.query(Team).delete() + db.session.query(Driver).delete() + db.session.query(Race).delete() # Reload static data for row in load_csv("static_data/teams.csv"): @@ -49,18 +50,18 @@ def reload_static_data(db): db.session.commit() -def reload_dynamic_data(db): +def reload_dynamic_data(db: SQLAlchemy): print("Initializing Database with Dynamic Values...") # Create it (if it doesn't exist!) db.create_all() # Clear dynamic data - User.query.delete() - RaceResult.query.delete() - RaceGuess.query.delete() - TeamWinners.query.delete() - PodiumDrivers.query.delete() - SeasonGuess.query.delete() + db.session.query(User).delete() + db.session.query(RaceResult).delete() + db.session.query(RaceGuess).delete() + db.session.query(TeamWinners).delete() + db.session.query(PodiumDrivers).delete() + db.session.query(SeasonGuess).delete() # Reload dynamic data for row in load_csv("dynamic_data/users.csv"): @@ -82,12 +83,12 @@ def reload_dynamic_data(db): def export_dynamic_data(): print("Exporting Userdata...") - users = User.query.all() - raceresults = RaceResult.query.all() - raceguesses = RaceGuess.query.all() - teamwinners = TeamWinners.query.all() - podiumdrivers = PodiumDrivers.query.all() - seasonguesses = SeasonGuess.query.all() + users: List[User] = User.query.all() + raceresults: List[RaceResult] = RaceResult.query.all() + raceguesses: List[RaceGuess] = RaceGuess.query.all() + teamwinners: List[TeamWinners] = TeamWinners.query.all() + podiumdrivers: List[PodiumDrivers] = PodiumDrivers.query.all() + seasonguesses: List[SeasonGuess] = SeasonGuess.query.all() write_csv("dynamic_data/users.csv", users) write_csv("dynamic_data/raceresults.csv", raceresults) diff --git a/formula10.py b/formula10.py index 7700ecf..61edd21 100644 --- a/formula10.py +++ b/formula10.py @@ -1,6 +1,6 @@ from urllib.parse import unquote - from flask import Flask, render_template, request, redirect +from werkzeug import Response from model import * from database_utils import reload_static_data, reload_dynamic_data, export_dynamic_data from template_model import TemplateModel @@ -30,42 +30,42 @@ db.init_app(app) @app.route("/") -def root(): - return race_active_user("Everyone") +def root() -> Response: + return redirect("/race/Everyone") -@app.route("/save/all", strict_slashes=False) -def save(): +@app.route("/save/all") +def save() -> Response: export_dynamic_data() return redirect("/") @app.route("/load/all") -def load(): +def load() -> Response: reload_static_data(db) reload_dynamic_data(db) return redirect("/") @app.route("/load/static") -def load_static(): +def load_static() -> Response: reload_static_data(db) return redirect("/") @app.route("/load/dynamic") -def load_dynamic(): +def load_dynamic() -> Response: reload_dynamic_data(db) return redirect("/") @app.route("/race") -def race_root(): +def race_root() -> Response: return redirect("/race/Everyone") @app.route("/race/") -def race_active_user(user_name: str): +def race_active_user(user_name: str) -> str: user_name = unquote(user_name) model = TemplateModel() return render_template("race.jinja", @@ -74,7 +74,7 @@ def race_active_user(user_name: str): @app.route("/race-guess//", methods=["POST"]) -def race_guess_post(race_name: str, user_name: str): +def race_guess_post(race_name: str, user_name: str) -> Response: race_name = unquote(race_name) user_name = unquote(user_name) @@ -88,7 +88,7 @@ def race_guess_post(race_name: str, user_name: str): print("Error: Can't guess race result if the race result is already known!") return redirect(f"/race/{quote(user_name)}") - raceguess: RaceGuess | None = RaceGuess.query.filter_by(user_name=user_name, race_name=race_name).first() + raceguess: RaceGuess | None = db.session.query(RaceGuess).filter_by(user_name=user_name, race_name=race_name).first() if raceguess is None: raceguess = RaceGuess() @@ -104,12 +104,12 @@ def race_guess_post(race_name: str, user_name: str): @app.route("/season") -def season_root(): +def season_root() -> Response: return redirect("/season/Everyone") @app.route("/season/") -def season_active_user(user_name: str): +def season_active_user(user_name: str) -> str: user_name = unquote(user_name) model = TemplateModel() return render_template("season.jinja", @@ -118,7 +118,7 @@ def season_active_user(user_name: str): @app.route("/season-guess/", methods=["POST"]) -def season_guess_post(user_name: str): +def season_guess_post(user_name: str) -> Response: user_name = unquote(user_name) guesses: List[str | None] = [ request.form.get("hottakeselect"), @@ -129,7 +129,7 @@ def season_guess_post(user_name: str): request.form.get("lostselect") ] teamwinnerguesses: List[str | None] = [ - request.form.get(f"teamwinner-{team.name}") for team in Team.query.all() + request.form.get(f"teamwinner-{team.name}") for team in db.session.query(Team).all() ] podiumdriverguesses: List[str] = request.form.getlist("podiumdrivers") @@ -137,7 +137,7 @@ def season_guess_post(user_name: str): print("Error: /guessseason could not obtain request data!") return redirect(f"/season/{quote(user_name)}") - seasonguess: SeasonGuess | None = SeasonGuess.query.filter_by(user_name=user_name).first() + seasonguess: SeasonGuess | None = db.session.query(SeasonGuess).filter_by(user_name=user_name).first() teamwinners: TeamWinners | None = seasonguess.team_winners if seasonguess is not None else None podiumdrivers: PodiumDrivers | None = seasonguess.podium_drivers if seasonguess is not None else None @@ -176,12 +176,12 @@ def season_guess_post(user_name: str): @app.route("/result") -def result_root(): +def result_root() -> Response: return redirect("/result/Current") @app.route("/result/") -def result_active_race(race_name: str): +def result_active_race(race_name: str) -> str: race_name = unquote(race_name) model = TemplateModel() return render_template("enter.jinja", @@ -190,7 +190,7 @@ def result_active_race(race_name: str): @app.route("/result-enter/", methods=["POST"]) -def result_enter_post(result_race_name: str): +def result_enter_post(result_race_name: str) -> Response: result_race_name = unquote(result_race_name) pxxs: List[str] = request.form.getlist("pxxdrivers") dnfs: List[str] = request.form.getlist("dnf-drivers") @@ -200,7 +200,7 @@ def result_enter_post(result_race_name: str): pxxs_dict: Dict[str, str] = {str(position + 1): driver for position, driver in enumerate(pxxs)} dnfs_dict: Dict[str, str] = {str(position + 1): driver for position, driver in enumerate(pxxs) if driver in dnfs} - raceresult: RaceResult | None = RaceResult.query.filter_by(race_name=result_race_name).first() + raceresult: RaceResult | None = db.session.query(RaceResult).filter_by(race_name=result_race_name).first() if raceresult is None: raceresult = RaceResult() @@ -212,7 +212,7 @@ def result_enter_post(result_race_name: str): raceresult.excluded_driver_names = excludes db.session.commit() - race: Race | None = Race.query.filter_by(name=result_race_name).first() + race: Race | None = db.session.query(Race).filter_by(name=result_race_name).first() if race is None: print("Error: Can't redirect to /enter/ because race couldn't be found") return redirect(f"/result/Current") @@ -221,7 +221,7 @@ def result_enter_post(result_race_name: str): @app.route("/user") -def user_root(): +def user_root() -> str: users: List[User] = User.query.all() return render_template("users.jinja", @@ -229,14 +229,14 @@ def user_root(): @app.route("/user-add", methods=["POST"]) -def user_add_post(): +def user_add_post() -> Response: username: str | None = request.form.get("select-add-user") if username is None or len(username) == 0: print(f"Not adding user, since no username was received") - return user_root() + return redirect("/user") - if len(User.query.filter_by(name=username).all()) > 0: + if len(db.session.query(User).filter_by(name=username).all()) > 0: print(f"Not adding user {username}: Already exists!") return redirect("/user") @@ -249,7 +249,7 @@ def user_add_post(): @app.route("/user-delete", methods=["POST"]) -def user_delete_post(): +def user_delete_post() -> Response: username = request.form.get("select-delete-user") if username is None or len(username) == 0: diff --git a/model.py b/model.py index 6aa4ad8..13ecdd8 100644 --- a/model.py +++ b/model.py @@ -1,13 +1,12 @@ import json from datetime import datetime -from typing import List, Dict +from typing import Any, List, Dict from urllib.parse import quote - from flask_sqlalchemy import SQLAlchemy from sqlalchemy import Integer, String, DateTime, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship -db = SQLAlchemy() +db: SQLAlchemy = SQLAlchemy() #################################### # Static Data (Defined in Backend) # @@ -21,7 +20,7 @@ class Race(db.Model): """ __tablename__ = "race" - def from_csv(self, row): + def from_csv(self, row: List[str]): self.name = str(row[0]) self.number = int(row[1]) self.date = datetime.strptime(row[2], "%Y-%m-%d") @@ -44,7 +43,7 @@ class Team(db.Model): """ __tablename__ = "team" - def from_csv(self, row): + def from_csv(self, row: List[str]): self.name = str(row[0]) return self @@ -58,7 +57,7 @@ class Driver(db.Model): """ __tablename__ = "driver" - def from_csv(self, row): + def from_csv(self, row: List[str]): self.name = str(row[0]) self.abbr = str(row[1]) self.team_name = str(row[2]) @@ -86,11 +85,11 @@ class User(db.Model): __tablename__ = "user" __csv_header__ = ["name"] - def from_csv(self, row): + def from_csv(self, row: List[str]): self.name = str(row[0]) return self - def to_csv(self): + def to_csv(self) -> List[Any]: return [ self.name ] @@ -111,14 +110,14 @@ class RaceResult(db.Model): __allow_unmapped__ = True # TODO: Used for json conversion, move this to some other class instead __csv_header__ = ["race_name", "pxx_driver_names_json", "dnf_driver_names_json", "excluded_driver_names_json"] - def from_csv(self, row): + def from_csv(self, row: List[str]): self.race_name = str(row[0]) self.pxx_driver_names_json = str(row[1]) self.dnf_driver_names_json = str(row[2]) self.excluded_driver_names_json = str(row[3]) return self - def to_csv(self): + def to_csv(self) -> List[Any]: return [ self.race_name, self.pxx_driver_names_json, @@ -166,7 +165,7 @@ class RaceResult(db.Model): if self._pxx_drivers is None: self._pxx_drivers = dict() for position, driver_name in self.pxx_driver_names.items(): - driver = Driver.query.filter_by(name=driver_name).first() + driver: Driver | None = db.session.query(Driver).filter_by(name=driver_name).first() if driver is None: raise Exception(f"Error: Couldn't find driver with id {driver_name}") @@ -176,11 +175,11 @@ class RaceResult(db.Model): @property def pxx_drivers_values(self) -> List[Driver]: - drivers: List[Driver] = [] + drivers: List[Driver] = list() # I don't know what order dict.values() etc. will return... for position in range(1, 21): - drivers += [self.pxx_drivers[str(position)]] + drivers.append(self.pxx_drivers[str(position)]) return drivers @@ -189,7 +188,7 @@ class RaceResult(db.Model): if self._dnf_drivers is None: self._dnf_drivers = dict() for position, driver_name in self.dnf_driver_names.items(): - driver = Driver.query.filter_by(name=driver_name).first() + driver: Driver | None = db.session.query(Driver).filter_by(name=driver_name).first() if driver is None: raise Exception(f"Error: Couldn't find driver with id {driver_name}") @@ -200,13 +199,13 @@ class RaceResult(db.Model): @property def excluded_drivers(self) -> List[Driver]: if self._excluded_drivers is None: - self._excluded_drivers = [] + self._excluded_drivers = list() for driver_name in self.excluded_driver_names: - driver = Driver.query.filter_by(name=driver_name).first() + driver: Driver | None = db.session.query(Driver).filter_by(name=driver_name).first() if driver is None: raise Exception(f"Error: Couldn't find driver with id {driver_name}") - self._excluded_drivers += [driver] + self._excluded_drivers.append(driver) return self._excluded_drivers @@ -230,14 +229,14 @@ class RaceGuess(db.Model): __tablename__ = "raceguess" __csv_header__ = ["user_name", "race_name", "pxx_driver_name", "dnf_driver_name"] - def from_csv(self, row): + def from_csv(self, row: List[str]): self.user_name = str(row[0]) self.race_name = str(row[1]) self.pxx_driver_name = str(row[2]) self.dnf_driver_name = str(row[3]) return self - def to_csv(self): + def to_csv(self) -> List[Any]: return [ self.user_name, self.race_name, @@ -265,12 +264,12 @@ class TeamWinners(db.Model): __allow_unmapped__ = True __csv_header__ = ["user_name", "teamwinner_driver_names_json"] - def from_csv(self, row): + def from_csv(self, row: List[str]): self.user_name = str(row[0]) self.teamwinner_driver_names_json = str(row[1]) return self - def to_csv(self): + def to_csv(self) -> List[Any]: return [ self.user_name, self.teamwinner_driver_names_json @@ -294,13 +293,13 @@ class TeamWinners(db.Model): @property def teamwinners(self) -> List[Driver]: if self._teamwinner_drivers is None: - self._teamwinner_drivers = [] + self._teamwinner_drivers = list() for driver_name in self.teamwinner_driver_names: - driver = Driver.query.filter_by(name=driver_name).first() + driver: Driver | None = db.session.query(Driver).filter_by(name=driver_name).first() if driver is None: raise Exception(f"Error: Couldn't find driver with id {driver_name}") - self._teamwinner_drivers += [driver] + self._teamwinner_drivers.append(driver) return self._teamwinner_drivers @@ -313,12 +312,12 @@ class PodiumDrivers(db.Model): __allow_unmapped__ = True __csv_header__ = ["user_name", "podium_driver_names_json"] - def from_csv(self, row): + def from_csv(self, row: List[str]): self.user_name = str(row[0]) self.podium_driver_names_json = str(row[1]) return self - def to_csv(self): + def to_csv(self) -> List[Any]: return [ self.user_name, self.podium_driver_names_json @@ -342,13 +341,13 @@ class PodiumDrivers(db.Model): @property def podium_drivers(self) -> List[Driver]: if self._podium_drivers is None: - self._podium_drivers = [] + self._podium_drivers = list() for driver_name in self.podium_driver_names: - driver = Driver.query.filter_by(name=driver_name).first() + driver: Driver | None = db.session.query(Driver).filter_by(name=driver_name).first() if driver is None: raise Exception(f"Error: Couldn't find driver with id {driver_name}") - self._podium_drivers += [driver] + self._podium_drivers.append(driver) return self._podium_drivers @@ -362,7 +361,7 @@ class SeasonGuess(db.Model): "overtake_driver_name", "dnf_driver_name", "gained_driver_name", "lost_driver_name", "team_winners_id", "podium_drivers_id"] - def from_csv(self, row): + def from_csv(self, row: List[str]): self.user_name = str(row[0]) # Also used as foreign key for teamwinners + podiumdrivers self.hot_take = str(row[1]) self.p2_team_name = str(row[2]) @@ -374,7 +373,7 @@ class SeasonGuess(db.Model): self.podium_drivers_id = str(row[8]) return self - def to_csv(self): + def to_csv(self) -> List[Any]: return [ self.user_name, self.hot_take, diff --git a/template_model.py b/template_model.py index d6cc430..2a6a555 100644 --- a/template_model.py +++ b/template_model.py @@ -1,8 +1,6 @@ -from typing import List, Iterable, Callable, TypeVar, Dict, overload, Any - +from typing import List, Iterable, Callable, TypeVar, Dict, overload from sqlalchemy import desc - -from model import User, RaceResult, RaceGuess, Race, Driver, Team, SeasonGuess +from model import User, RaceResult, RaceGuess, Race, Driver, Team, SeasonGuess, db _T = TypeVar("_T") @@ -73,7 +71,7 @@ class TemplateModel: Returns a list of all users in the database. """ if self._all_users is None: - self._all_users = User.query.all() + self._all_users = db.session.query(User).all() return self._all_users @@ -106,7 +104,7 @@ class TemplateModel: Returns a list of all race results in the database, in descending order (most recent first). """ if self._all_race_results is None: - self._all_race_results = RaceResult.query.join(RaceResult.race).order_by(desc(Race.number)).all() + self._all_race_results = db.session.query(RaceResult).join(RaceResult.race).order_by(desc(Race.number)).all() return self._all_race_results @@ -122,26 +120,26 @@ class TemplateModel: Returns a list of all race guesses in the database. """ if self._all_race_guesses is None: - self._all_race_guesses = RaceGuess.query.all() + self._all_race_guesses = db.session.query(RaceGuess).all() return self._all_race_guesses @overload - def race_guesses_by(self, *, user_name) -> List[RaceGuess]: + def race_guesses_by(self, *, user_name: str) -> List[RaceGuess]: """ Returns a list of all race guesses made by a specific user. """ return self.race_guesses_by(user_name=user_name) @overload - def race_guesses_by(self, *, race_name) -> List[RaceGuess]: + def race_guesses_by(self, *, race_name: str) -> List[RaceGuess]: """ Returns a list of all race guesses made for a specific race. """ return self.race_guesses_by(race_name=race_name) @overload - def race_guesses_by(self, *, user_name, race_name) -> RaceGuess | None: + def race_guesses_by(self, *, user_name: str, race_name: str) -> RaceGuess | None: """ Returns a single race guess by a specific user for a specific race, or None, if this guess doesn't exist. """ @@ -154,7 +152,7 @@ class TemplateModel: """ return self.race_guesses_by() - def race_guesses_by(self, *, user_name=None, race_name=None) -> RaceGuess | List[RaceGuess] | Dict[str, Dict[str, RaceGuess]] | None: + def race_guesses_by(self, *, user_name: str | None = None, race_name: str | None = None) -> RaceGuess | List[RaceGuess] | Dict[str, Dict[str, RaceGuess]] | None: # List of all guesses by a single user if user_name is not None and race_name is None: predicate: Callable[[RaceGuess], bool] = lambda guess: guess.user_name == user_name @@ -187,12 +185,12 @@ class TemplateModel: def all_season_guesses(self) -> List[SeasonGuess]: if self._all_season_guesses is None: - self._all_season_guesses = SeasonGuess.query.all() + self._all_season_guesses = db.session.query(SeasonGuess).all() return self._all_season_guesses @overload - def season_guesses_by(self, *, user_name) -> SeasonGuess: + def season_guesses_by(self, *, user_name: str) -> SeasonGuess: """ Returns the season guess made by a specific user. """ @@ -205,7 +203,7 @@ class TemplateModel: """ return self.season_guesses_by() - def season_guesses_by(self, *, user_name=None) -> SeasonGuess | Dict[str, SeasonGuess] | None: + def season_guesses_by(self, *, user_name: str | None = None) -> SeasonGuess | Dict[str, SeasonGuess] | None: if user_name is not None: predicate: Callable[[SeasonGuess], bool] = lambda guess: guess.user_name == user_name return find_single_or_none(predicate, self.all_season_guesses()) @@ -226,7 +224,7 @@ class TemplateModel: Returns a list of all races in the database. """ if self._all_races is None: - self._all_races = Race.query.order_by(desc(Race.number)).all() + self._all_races = db.session.query(Race).order_by(desc(Race.number)).all() return self._all_races @@ -248,7 +246,7 @@ class TemplateModel: Returns a list of all teams in the database. """ if self._all_teams is None: - self._all_teams = Team.query.all() + self._all_teams = db.session.query(Team).all() return self._all_teams @@ -257,7 +255,7 @@ class TemplateModel: Returns a list of all drivers in the database, including the NONE driver. """ if self._all_drivers is None: - self._all_drivers = Driver.query.all() + self._all_drivers = db.session.query(Driver).all() return self._all_drivers @@ -269,7 +267,7 @@ class TemplateModel: return find_multiple(predicate, self.all_drivers()) @overload - def drivers_by(self, *, team_name) -> List[Driver]: + def drivers_by(self, *, team_name: str) -> List[Driver]: """ Returns a list of all drivers driving for a certain team. """ @@ -282,7 +280,7 @@ class TemplateModel: """ return self.drivers_by() - def drivers_by(self, *, team_name=None) -> List[Driver] | Dict[str, List[Driver]]: + def drivers_by(self, *, team_name: str | None = None) -> List[Driver] | Dict[str, List[Driver]]: if team_name is not None: predicate: Callable[[Driver], bool] = lambda driver: driver.team.name == team_name return find_multiple(predicate, self.all_drivers_except_none(), 2) diff --git a/templates/base.jinja b/templates/base.jinja index 7001846..a2c6e28 100644 --- a/templates/base.jinja +++ b/templates/base.jinja @@ -146,7 +146,7 @@ P{{ result.race.pxx + 3 }}: {{ result.pxx(3).abbr }} -