# File: PyConES24/src/retailtwin/bootstrap.py
# 270 lines, 8.5 KiB, Python (viewer options: Raw | Normal View | History)
# Last change: 2024-05-08 23:55:08 +02:00
import json
import string
from datetime import datetime, timedelta
from itertools import repeat
from pathlib import Path
from random import choice, choices, randint
from uuid import uuid4
import polars as pl
from sqlalchemy import select
from sqlalchemy.orm import Session
import retailtwin
from retailtwin.utils import db_uri_from_session
from retailtwin.models import (
Discount,
Item,
ItemBatch,
ItemOnShelf,
Location,
LocationType,
Provider,
TaskOwner,
)
# ---------------------------------------------------------------------------
# Configuration parameters for the bootstrap process.
# ---------------------------------------------------------------------------

# Directory of the installed ``retailtwin`` package; the CSV fixtures below
# are resolved relative to it.
PACKAGE_ROOT = Path(retailtwin.__file__).parent

# Fixture files shipped inside the package (relative to PACKAGE_ROOT).
PRODUCT_LIST_FILE, DISCOUNT_LIST_FILE, RANDOM_PEOPLE_FILE = (
    "data/products.csv",
    "data/discounts.csv",
    "data/random_people.csv",
)

# Locations: how many to create, which kinds exist, and the relative weight
# used when drawing a kind at random (on average 10 stores per warehouse).
NUM_LOCATIONS = 100
LOCATION_TYPES = ["store", "warehouse"]
LOCATION_TYPES_WEIGHTS = [10, 1]

# Customers: total number of rows to generate and rows written per batch.
NUM_CUSTOMERS = 1_000_000
CUSTOMER_BATCH_SIZE = 10_000
# Some utility functions
def read_products() -> pl.DataFrame:
    """Read the dummy product catalogue bundled with the package and tidy it.

    Every column except ``product`` is looked up with a leading blank
    (``" package"``, ``" price"``, ``" provider"``), matching the raw CSV
    header. Surrounding spaces are removed from all values, and ``$`` is
    stripped from the price, which remains a string.

    Returns:
        pl.DataFrame: Columns ``product``, ``package``, ``provider`` and
        ``price``, in that order.
    """
    csv_path = (PACKAGE_ROOT / PRODUCT_LIST_FILE).resolve()
    raw = pl.read_csv(csv_path)

    def _trimmed(column: str) -> pl.Expr:
        # Strip the padding spaces around a raw CSV value.
        return pl.col(column).str.strip(" ")

    cleaned = raw.with_columns(
        product=_trimmed("product"),
        package=_trimmed(" package"),
        price=_trimmed(" price").str.strip("$"),
        provider=_trimmed(" provider"),
    )
    # Keep only the sanitised columns; the raw " ..." ones are dropped here.
    return cleaned.select(["product", "package", "provider", "price"])
def read_discounts() -> pl.DataFrame:
    """Read the discount definitions bundled with the package.

    Returns:
        pl.DataFrame: The contents of ``DISCOUNT_LIST_FILE`` exactly as read
        from the CSV, with no cleaning applied.
    """
    discounts_csv = PACKAGE_ROOT / DISCOUNT_LIST_FILE
    return pl.read_csv(discounts_csv.resolve())
def read_people() -> pl.DataFrame:
    """Read the dummy list of people used to synthesise customers.

    Returns:
        pl.DataFrame: The contents of ``RANDOM_PEOPLE_FILE`` exactly as read
        from the CSV, with no cleaning applied.
    """
    people_csv = PACKAGE_ROOT / RANDOM_PEOPLE_FILE
    return pl.read_csv(people_csv.resolve())
def bootstrap_discounts(session: Session):
    """Populate the discounts table from the bundled discount CSV.

    For every CSV row, the first field becomes the discount name and the
    second field is parsed as JSON to form its definition. Everything is
    committed once at the end.

    Args:
        session (Session): SQLAlchemy ORM session
    """
    session.add_all(
        Discount(name=row[0], definition=json.loads(row[1]))
        for row in read_discounts().iter_rows()
    )
    session.commit()
def bootstrap_providers(session: Session):
    """Load the providers table with one row per distinct provider in the
    product list.

    Only the name comes from the data. The address and phone are the same
    placeholder for every provider, and the VAT id is mocked as eight random
    digits followed by a random uppercase letter.

    Args:
        session (Session): SQLAlchemy ORM session
    """
    provider_names = read_products().select(pl.col("provider").unique()).to_series()
    for provider in provider_names:
        session.add(
            Provider(
                name=provider,
                address="Fake address street, number XX",
                phone="+1 555 555 55 55",
                # string.ascii_uppercase replaces a hand-built list in which
                # every letter appeared twice.
                vat=f"{randint(10_000_000, 99_999_999)}{choice(string.ascii_uppercase)}",
            )
        )
    session.commit()
def bootstrap_items(session: Session):
    """Load data into the items table from the product list.

    Each product becomes a current ``Item``. Its attributes are:

    - a random UPC in ``[0, 999999999999]``;
    - a random unpacked volume in ``[2, 10]``;
    - a packed volume one less than the unpacked volume.

    The product's provider name is resolved to the id of the matching
    ``Provider`` row, so the providers table must be loaded beforehand.

    Args:
        session (Session): SQLAlchemy ORM session

    Raises:
        KeyError: If a product names a provider that is not in the providers
            table.
    """
    # Resolve provider names to ids with one query instead of one per product.
    provider_ids = {
        provider.name: provider.id for provider in session.scalars(select(Provider))
    }
    for row in read_products().iter_rows():
        # read_products() yields (product, package, provider, price).
        name, package, provider_name = row[0], row[1], row[2]
        volume = randint(2, 10)
        session.add(
            Item(
                name=name,
                upc=randint(0, 999999999999),
                package=package,
                current=True,
                provider=provider_ids[provider_name],
                volume_unpacked=volume,
                volume_packed=volume - 1,
            )
        )
    session.commit()
def bootstrap_locations(session: Session):
    """Create NUM_LOCATIONS locations whose types are drawn from LOCATION_TYPES.

    The first step inserts one ``LocationType`` per entry of LOCATION_TYPES;
    only ``"store"`` is flagged as retail. The second step creates each
    location with a type drawn at random using LOCATION_TYPES_WEIGHTS as
    relative weights.

    Capacity is a random value in ``[50_000, 100_000]``, multiplied by 10 for
    warehouses.

    Args:
        session (Session): SQLAlchemy ORM session
    """
    for ltype in LOCATION_TYPES:
        session.add(LocationType(name=ltype, retail=(ltype == "store")))
    session.commit()

    # Look the committed types up once instead of querying for every location.
    types_by_name = {
        location_type.name: location_type
        for location_type in session.scalars(select(LocationType))
    }
    for i in range(NUM_LOCATIONS):
        loc = types_by_name[choices(LOCATION_TYPES, LOCATION_TYPES_WEIGHTS)[0]]
        scale = 10 if loc.name == "warehouse" else 1
        session.add(
            Location(
                loctype=loc.id,
                name=f"Location {i} {loc.name}",
                capacity=randint(50_000, 100_000) * scale,
            )
        )
    session.commit()
def _sample_full_names(people: pl.DataFrame, size: int) -> pl.Series:
    """Return ``size`` random "name middlename surname" strings, named ``name``.

    Each part is sampled independently (with replacement) from its own column
    of ``people``, so the resulting combinations need not exist in the source.
    """
    parts = pl.DataFrame(
        {
            column: people.select(pl.col(column))
            .sample(size, with_replacement=True, shuffle=True)
            .to_series()
            for column in ("name", "middlename", "surname")
        }
    )
    return parts.select(
        pl.concat_str(
            [pl.col("name"), pl.col("middlename"), pl.col("surname")],
            separator=" ",
        ).alias("name")
    ).to_series()


def bootstrap_clients(session: Session):
    """Load NUM_CUSTOMERS synthetic customers into the ``customers`` table.

    Rows are written in batches of CUSTOMER_BATCH_SIZE. A final, smaller batch
    covers any remainder, so exactly NUM_CUSTOMERS rows are appended even when
    the total is not a multiple of the batch size.

    Each row has two fields:

    - ``document``: a random zero-padded 8-digit string;
    - ``info``: a JSON object holding a randomly combined full name.

    Data is written with ``DataFrame.write_database`` using the session's
    database URI, not through the ORM session itself.

    Args:
        session (Session): SQLAlchemy ORM session
    """
    connection_uri = db_uri_from_session(session)
    people = read_people()

    full_batches, remainder = divmod(NUM_CUSTOMERS, CUSTOMER_BATCH_SIZE)
    batch_sizes = [CUSTOMER_BATCH_SIZE] * full_batches
    if remainder:
        batch_sizes.append(remainder)

    for batch_number, batch_size in enumerate(batch_sizes, start=1):
        print(f"Write batch {batch_number} of {len(batch_sizes)}")
        names = _sample_full_names(people, batch_size)
        pl.DataFrame(
            {
                "document": [
                    f"{randint(0, 99999999):08d}" for _ in range(batch_size)
                ],
                "info": [json.dumps({"name": name}) for name in names],
            },
        ).write_database(
            "customers", connection_uri, if_exists="append", engine="sqlalchemy"
        )
def bootstrap_stock(session: Session):
    """Stock every location with every item.

    For each location and each item that matches a catalogue product (on name
    and package), one ``ItemBatch`` and one ``ItemOnShelf`` are created.

    Quantities: the per-item quantity is
    ``capacity / average unpacked volume / number of products``, which aims
    not to stock the location over capacity.

    Lots and dates: all batches stocked into the same location share one
    random lot id. Each batch is received "now" and is best before 30 days
    later.

    Discounts: each shelf entry has no discount with weight 50, or any single
    existing discount with weight 1.

    Args:
        session (Session): SQLAlchemy ORM session
    """
    products = read_products()
    db_uri = db_uri_from_session(session)
    items_df = pl.read_database("select * from items", db_uri, engine="adbc")
    locations_df = pl.read_database("select * from locations", db_uri, engine="adbc")
    discounts = list(
        pl.read_database("select id from discounts", db_uri, engine="adbc")
        .select(pl.col("id"))
        .to_series()
    ) + [None]
    # Every real discount weighs 1; the trailing None ("no discount") weighs 50.
    weights = [1] * (len(discounts) - 1) + [50]
    average_volume_unpacked = items_df.select(pl.col("volume_unpacked").mean())[0, 0]

    # The price join does not depend on the location, so compute it only once.
    with_prices = items_df.join(
        products,
        left_on=["name", "package"],
        right_on=["product", "package"],
    )
    prices = list(with_prices.select([pl.col("price")]).to_series())
    skus = list(with_prices.select([pl.col("sku")]).to_series())

    for location in locations_df.iter_rows():
        # NOTE(review): location[0] is used as the id and location[3] as the
        # capacity; this relies on the column order of `select * from locations`.
        print(f"Stocking location {location[0]} from {len(locations_df)}")
        stock_quantity = int(location[3] / average_volume_unpacked / len(products))
        lot_id = str(uuid4())[:23]
        for unit_price, sku in zip(prices, skus):
            received = datetime.now()
            batch = ItemBatch(
                sku=sku,
                lot=lot_id,
                order=None,
                received=received,
                unit_cost=unit_price,  # TODO: There's no margin for the moment
                price=unit_price,
                best_until=received + timedelta(days=30),
                quantity=stock_quantity,
            )
            session.add(batch)
            # Flushing is enough for the database to assign batch.id; there is
            # no need for a full commit per item.
            session.flush()
            session.add(
                ItemOnShelf(
                    batch=batch.id,
                    # None means "no discount"; any real id (even 0) is kept.
                    discount=choices(discounts, weights, k=1)[0],
                    quantity=stock_quantity,
                    location=location[0],
                )
            )
        session.commit()
def bootstrap_taskowners(session: Session):
    """Create the fixed set of task owners and commit them.

    The owners are ``stocker``, ``cashier``, ``manager`` and ``warehouse``.

    Args:
        session (Session): SQLAlchemy ORM session
    """
    owner_names = ("stocker", "cashier", "manager", "warehouse")
    session.add_all([TaskOwner(name=owner_name) for owner_name in owner_names])
    session.commit()