1620 lines
63 KiB
Python
1620 lines
63 KiB
Python
"""
|
||
Lifecycle-Aware Forecast Engine
|
||
|
||
Generates 90-day per-product daily sales forecasts using analogous lifecycle
|
||
curves learned from historical brand/category launch patterns.
|
||
|
||
Usage:
|
||
python forecast_engine.py
|
||
python forecast_engine.py --backfill 30
|
||
|
||
Environment variables (from .env):
|
||
DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_PORT (default 5432)
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import time
|
||
import logging
|
||
from datetime import datetime, date, timedelta
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import psycopg2
|
||
import psycopg2.extras
|
||
import psycopg2.extensions
|
||
from scipy.optimize import curve_fit
|
||
from statsmodels.tsa.holtwinters import SimpleExpSmoothing, Holt
|
||
|
||
# Teach psycopg2 how to render numpy scalar types as SQL literals; without
# these adapters, passing e.g. np.int64 query params raises ProgrammingError.
for _np_type, _to_python in (
    (np.float64, float),
    (np.float32, float),
    (np.int64, int),
    (np.int32, int),
):
    psycopg2.extensions.register_adapter(
        _np_type,
        # Bind the converter as a default arg to avoid late-binding closures.
        lambda value, _cast=_to_python: psycopg2.extensions.AsIs(_cast(value)),
    )
|
||
|
||
# ---------------------------------------------------------------------------
# Config — tuning knobs for curve fitting, classification, and forecasting
# ---------------------------------------------------------------------------
FORECAST_HORIZON_DAYS = 90       # length of every generated daily forecast
CURVE_HISTORY_DAYS = 730         # 2 years of launches to build reference curves
CURVE_WINDOW_WEEKS = 13          # Track decay for 13 weeks (91 days)
MIN_PRODUCTS_FOR_CURVE = 5       # Minimum launches to fit a brand curve
MIN_PRODUCTS_FOR_BRAND_CAT = 10  # Minimum for brand x category curve
MATURE_VELOCITY_THRESHOLD = 0.1  # units/day to qualify as "mature" vs "dormant"
MATURE_AGE_DAYS = 60             # days since first_received to be considered mature
LAUNCH_AGE_DAYS = 14             # days in "launch" phase
DECAY_AGE_DAYS = 60              # days in "active decay" phase
EXP_SMOOTHING_WINDOW = 60        # days of history for mature product smoothing
BATCH_SIZE = 1000                # rows per INSERT batch
DOW_LOOKBACK_DAYS = 90           # days of order history for day-of-week indices
MIN_R_SQUARED = 0.1              # curves below this are unreliable (fall back to velocity)
SEASONAL_LOOKBACK_DAYS = 365     # 12 months of order history for monthly seasonal indices
MIN_PREORDER_DAYS = 3            # minimum pre-order accumulation days for reliable scaling
MAX_SMOOTHING_MULTIPLIER = 10    # cap exp smoothing forecast at Nx observed velocity
|
||
|
||
# Console logging for the batch run; time-only timestamps keep lines compact
# since each engine invocation is short-lived.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S'
)
log = logging.getLogger('forecast')
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Database helpers
|
||
# ---------------------------------------------------------------------------
|
||
def get_connection():
    """Open a new PostgreSQL connection configured from DB_* env variables.

    Falls back to localhost/inventory defaults when a variable is unset.
    """
    env = os.environ
    return psycopg2.connect(
        host=env.get('DB_HOST', 'localhost'),
        user=env.get('DB_USER', 'inventory_user'),
        password=env.get('DB_PASSWORD', ''),
        dbname=env.get('DB_NAME', 'inventory_db'),
        port=int(env.get('DB_PORT', 5432)),
    )
|
||
|
||
|
||
def execute_query(conn, sql, params=None):
    """Run *sql* on *conn* and return the result set as a pandas DataFrame.

    pandas emits a UserWarning when given a raw DBAPI connection instead of
    a SQLAlchemy engine; that specific warning is suppressed here since the
    engine intentionally uses psycopg2 directly.
    """
    import warnings

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*pandas only supports SQLAlchemy.*")
        frame = pd.read_sql_query(sql, conn, params=params)
    return frame
|
||
|
||
|
||
def cleanup_stale_runs(conn):
    """Mark any runs stuck in 'running' status as failed (e.g. from prior crashes).

    Only runs that started more than an hour ago are touched, so a currently
    executing run is never clobbered. Commits the update before returning.
    """
    with conn.cursor() as cur:
        cur.execute("""
            UPDATE forecast_runs
            SET status = 'failed', finished_at = NOW(),
                error_message = 'Stale run cleaned up on engine restart'
            WHERE status = 'running'
              AND started_at < NOW() - INTERVAL '1 hour'
        """)
        stale_count = cur.rowcount
    conn.commit()
    if stale_count:
        log.info(f"Cleaned up {stale_count} stale 'running' forecast run(s)")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Decay curve model: sales(t) = A * exp(-λt) + C
|
||
# ---------------------------------------------------------------------------
|
||
def decay_model(t, amplitude, decay_rate, baseline):
    """Exponential decay toward a floor: amplitude * exp(-decay_rate * t) + baseline."""
    return baseline + amplitude * np.exp(-decay_rate * t)
|
||
|
||
|
||
def fit_decay_curve(weekly_medians):
    """
    Fit sales(t) = A * exp(-λt) + C to median weekly sales.

    Args:
        weekly_medians: array of median sales per week (index = week number)

    Returns:
        (amplitude, decay_rate, baseline, r_squared) tuple, or None when the
        series is degenerate (too short / all zero) or the optimizer fails.
    """
    y = np.asarray(weekly_medians, dtype=float)
    weeks = np.arange(y.size, dtype=float)

    # Nothing to fit: need at least 3 points with some signal.
    if y.size < 3 or np.max(y) == 0:
        return None

    # Starting point: amplitude at the series peak, baseline estimated from
    # the tail (second half), and a moderate decay rate.
    peak = float(np.max(y))
    tail_floor = float(np.min(y[y.size // 2:]))

    try:
        popt, _ = curve_fit(
            decay_model, weeks, y,
            p0=[peak, 0.3, tail_floor],
            bounds=([0, 0.01, 0], [peak * 5, 5.0, peak]),
            maxfev=5000,
        )
    except (RuntimeError, ValueError) as exc:
        log.debug(f"Curve fit failed: {exc}")
        return None

    amplitude, decay_rate, baseline = popt

    # Goodness of fit (coefficient of determination).
    fitted = decay_model(weeks, *popt)
    ss_res = np.sum((y - fitted) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    r_sq = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0

    return float(amplitude), float(decay_rate), float(baseline), float(r_sq)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Day-of-week indices
|
||
# ---------------------------------------------------------------------------
|
||
def compute_dow_indices(conn):
    """
    Derive day-of-week revenue multipliers from recent order history.

    Returns {iso_weekday: multiplier} for weekdays 1 (Mon) through 7 (Sun),
    normalized so the seven values sum to 7.0 (average 1.0). Applying them
    therefore reshapes the daily distribution without changing weekly totals.
    Falls back to a flat distribution when fewer than 7 weekdays have data.
    """
    sql = """
        SELECT
            EXTRACT(ISODOW FROM o.date)::int AS dow,
            SUM(o.price * o.quantity) AS revenue
        FROM orders o
        WHERE o.canceled IS DISTINCT FROM TRUE
          AND o.date >= CURRENT_DATE - INTERVAL '1 day' * %s
        GROUP BY 1
        ORDER BY 1
    """

    df = execute_query(conn, sql, [DOW_LOOKBACK_DAYS])

    if df.empty or len(df) < 7:
        log.warning("Insufficient order data for DOW indices, using flat distribution")
        return {d: 1.0 for d in range(1, 8)}

    daily_avg = df['revenue'].sum() / 7.0

    # Prefill every weekday at 1.0, then overwrite with observed ratios.
    indices = {d: 1.0 for d in range(1, 8)}
    if daily_avg > 0:
        for _, rec in df.iterrows():
            indices[int(rec['dow'])] = round(float(rec['revenue']) / daily_avg, 4)

    log.info(f"DOW indices: Mon={indices[1]:.3f} Tue={indices[2]:.3f} Wed={indices[3]:.3f} "
             f"Thu={indices[4]:.3f} Fri={indices[5]:.3f} Sat={indices[6]:.3f} Sun={indices[7]:.3f}")
    return indices
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Monthly seasonal indices
|
||
# ---------------------------------------------------------------------------
|
||
def compute_monthly_seasonal_indices(conn):
    """
    Derive monthly seasonal revenue multipliers from recent order history.

    Returns {month: multiplier} for months 1-12, normalized so the observed
    months average 1.0: above-average months get >1, below-average get <1.
    Months with no data (and the whole result when fewer than 6 months are
    observed) default to a flat 1.0.
    """
    sql = """
        SELECT
            EXTRACT(MONTH FROM o.date)::int AS month,
            SUM(o.price * o.quantity) AS revenue
        FROM orders o
        WHERE o.canceled IS DISTINCT FROM TRUE
          AND o.date >= CURRENT_DATE - INTERVAL '1 day' * %s
        GROUP BY 1
        ORDER BY 1
    """

    df = execute_query(conn, sql, [SEASONAL_LOOKBACK_DAYS])

    if df.empty or len(df) < 6:
        log.warning("Insufficient data for seasonal indices, using flat distribution")
        return {m: 1.0 for m in range(1, 13)}

    # Average only over months that actually have revenue rows.
    monthly_avg = df['revenue'].sum() / len(df)

    # Prefill all 12 months at 1.0, then overwrite with observed ratios.
    indices = {m: 1.0 for m in range(1, 13)}
    if monthly_avg > 0:
        for _, rec in df.iterrows():
            indices[int(rec['month'])] = round(float(rec['revenue']) / monthly_avg, 4)

    present = [f"{m}={indices[m]:.3f}" for m in range(1, 13)]
    log.info(f"Monthly seasonal indices: {', '.join(present)}")
    return indices
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Phase 1: Build brand-category reference curves
|
||
# ---------------------------------------------------------------------------
|
||
# Promotional/deal categories excluded from lifecycle-curve fitting and lookup:
# sales inside these buckets reflect discount events, not organic launch decay,
# so letting them define curves would distort the brand baselines.
DEAL_CATEGORIES = frozenset([
    'Deals', 'Black Friday', 'Week 1', 'Week 2', 'Week 3',
    '28 Off', '5 Dollar Deals', '10 Dollar Deals', 'Fall Sale',
])
|
||
|
||
|
||
def build_reference_curves(conn):
    """
    Build decay curves for each brand (and brand x category at every hierarchy level).

    For category curves, we load each product's full set of category assignments
    (across all hierarchy levels), then fit brand×cat_id curves wherever we have
    enough products. This gives granular curves like "49 and Market × 12x12 Paper Pads"
    alongside coarser fallbacks like "49 and Market × Paper".

    Args:
        conn: open psycopg2 connection.

    Returns:
        DataFrame of curves written to brand_lifecycle_curves (empty DataFrame
        when no launch data exists or no curve passes the quality gates).

    Side effects:
        TRUNCATEs and repopulates brand_lifecycle_curves, then commits.
    """
    log.info("Building reference curves from historical launches...")

    # Get daily sales aligned by days-since-first-received for recent launches
    # (no category join here — we attach categories separately).
    # Launches newer than 14 days are excluded: too little post-launch history
    # to contribute a meaningful decay shape.
    sales_sql = """
        WITH recent_launches AS (
            SELECT pm.pid, p.brand
            FROM product_metrics pm
            JOIN products p ON p.pid = pm.pid
            WHERE p.visible = true
              AND p.brand IS NOT NULL
              AND pm.date_first_received >= NOW() - INTERVAL '1 day' * %s
              AND pm.date_first_received < NOW() - INTERVAL '14 days'
        ),
        daily_sales AS (
            SELECT
                rl.pid, rl.brand,
                dps.snapshot_date,
                COALESCE(dps.units_sold, 0) AS units_sold,
                (dps.snapshot_date - pm.date_first_received::date) AS day_offset
            FROM recent_launches rl
            JOIN product_metrics pm ON pm.pid = rl.pid
            JOIN daily_product_snapshots dps ON dps.pid = rl.pid
            WHERE dps.snapshot_date >= pm.date_first_received::date
              AND dps.snapshot_date < pm.date_first_received::date + INTERVAL '1 week' * %s
        )
        SELECT brand, pid,
               FLOOR(day_offset / 7)::int AS week_num,
               SUM(units_sold) AS weekly_sales
        FROM daily_sales
        WHERE day_offset >= 0
        GROUP BY brand, pid, week_num
        ORDER BY brand, pid, week_num
    """

    df = execute_query(conn, sales_sql, [CURVE_HISTORY_DAYS, CURVE_WINDOW_WEEKS])
    if df.empty:
        log.warning("No launch data found for reference curves")
        return pd.DataFrame()

    log.info(f"Loaded {len(df)} weekly sales records from {df['pid'].nunique()} products across {df['brand'].nunique()} brands")

    # Load all category assignments for these products (every hierarchy level),
    # skipping promotional categories that would distort curve membership.
    launch_pids = df['pid'].unique().tolist()
    cat_sql = """
        SELECT pc.pid, ch.cat_id, ch.name AS cat_name, ch.level AS cat_level
        FROM product_categories pc
        JOIN category_hierarchy ch ON ch.cat_id = pc.cat_id
        WHERE pc.pid = ANY(%s)
          AND ch.name NOT IN %s
        ORDER BY pc.pid, ch.level DESC
    """
    cat_df = execute_query(conn, cat_sql, [launch_pids, tuple(DEAL_CATEGORIES)])
    # Build pid -> list of (cat_id, cat_name, cat_level), deepest level first
    # (the ORDER BY above guarantees per-pid ordering).
    pid_cats = {}
    for _, row in cat_df.iterrows():
        pid = int(row['pid'])
        if pid not in pid_cats:
            pid_cats[pid] = []
        pid_cats[pid].append((int(row['cat_id']), row['cat_name'], int(row['cat_level'])))

    # Also get pre-order stats per brand (median pre-order sales AND accumulation window).
    # Uses de-facto preorders: any product that had orders before date_first_received,
    # regardless of the preorder_count flag. This gives us 6000+ completed cycles vs ~19
    # from the explicit flag alone.
    preorder_sql = """
        WITH preorder_stats AS (
            SELECT p.pid, p.brand,
                   COALESCE((SELECT SUM(o.quantity) FROM orders o
                             WHERE o.pid = p.pid AND o.canceled IS DISTINCT FROM TRUE
                               AND o.date < pm.date_first_received), 0) AS preorder_units,
                   GREATEST(EXTRACT(DAY FROM pm.date_first_received - MIN(o.date)), 1) AS preorder_days
            FROM products p
            JOIN product_metrics pm ON pm.pid = p.pid
            LEFT JOIN orders o ON o.pid = p.pid AND o.canceled IS DISTINCT FROM TRUE
                AND o.date < pm.date_first_received
            WHERE p.visible = true AND p.brand IS NOT NULL
              AND pm.date_first_received IS NOT NULL
              AND pm.date_first_received >= NOW() - INTERVAL '1 day' * %s
            GROUP BY p.pid, p.brand, pm.date_first_received
        )
        SELECT brand,
               PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY preorder_units) AS median_preorder_sales,
               PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY preorder_days) AS median_preorder_days
        FROM preorder_stats
        WHERE preorder_units > 0
        GROUP BY brand
        HAVING COUNT(*) >= 3
    """

    preorder_df = execute_query(conn, preorder_sql, [CURVE_HISTORY_DAYS])
    # brand -> median pre-order units / median pre-order accumulation days
    preorder_map = dict(zip(preorder_df['brand'], preorder_df['median_preorder_sales'])) if not preorder_df.empty else {}
    preorder_days_map = dict(zip(preorder_df['brand'], preorder_df['median_preorder_days'])) if not preorder_df.empty else {}

    curves = []

    def _fit_and_append(group_df, brand, cat_id=None, cat_name=None, cat_level=None):
        """Helper: fit a decay curve for a group and append to curves list.

        Returns True when a curve was fitted and stored, False when the group
        was too small, too short, or the fit failed/was unreliable.
        """
        product_count = group_df['pid'].nunique()
        # Brand-only curves need fewer samples than brand×category curves.
        min_products = MIN_PRODUCTS_FOR_CURVE if cat_id is None else MIN_PRODUCTS_FOR_BRAND_CAT

        if product_count < min_products:
            return False

        # Median across products per week gives a robust "typical launch" shape.
        weekly = group_df.groupby('week_num')['weekly_sales'].median()
        if len(weekly) < 4:
            return False

        # Pad missing weeks with zero sales up to the tracking window.
        full_weeks = weekly.reindex(range(CURVE_WINDOW_WEEKS), fill_value=0.0)
        weekly_arr = full_weeks.values[:CURVE_WINDOW_WEEKS]

        result = fit_decay_curve(weekly_arr)
        if result is None:
            return False

        amplitude, decay_rate, baseline, r_sq = result

        # Quality gate: only store curves above the reliability threshold
        if r_sq < MIN_R_SQUARED:
            return False

        # Median first-week sales across products — used later to scale the
        # curve to an individual product's observed demand.
        first_week = group_df[group_df['week_num'] == 0].groupby('pid')['weekly_sales'].sum()
        median_fw = float(first_week.median()) if len(first_week) > 0 else 0.0

        curves.append({
            'brand': brand,
            'root_category': cat_name,  # kept for readability; cat_id is the real key
            'cat_id': cat_id,
            'category_level': cat_level,
            'amplitude': amplitude,
            'decay_rate': decay_rate,
            'baseline': baseline,
            'r_squared': r_sq,
            'sample_size': product_count,
            'median_first_week_sales': median_fw,
            'median_preorder_sales': preorder_map.get(brand),
            'median_preorder_days': preorder_days_map.get(brand),
        })
        return True

    # 1. Fit brand-level curves (aggregate across all categories)
    for brand, brand_df in df.groupby('brand'):
        _fit_and_append(brand_df, brand)

    # 2. Fit brand × category curves at every hierarchy level
    # Build a mapping of (brand, cat_id) -> set of pids in that intersection
    brand_cat_pids = {}
    for pid, cats in pid_cats.items():
        brand_rows = df[df['pid'] == pid]
        if brand_rows.empty:
            continue
        brand = brand_rows.iloc[0]['brand']
        for cat_id, cat_name, cat_level in cats:
            key = (brand, cat_id)
            if key not in brand_cat_pids:
                brand_cat_pids[key] = {'cat_name': cat_name, 'cat_level': cat_level, 'pids': set()}
            brand_cat_pids[key]['pids'].add(pid)

    cat_curves_fitted = 0
    for (brand, cat_id), info in brand_cat_pids.items():
        group_df = df[(df['brand'] == brand) & (df['pid'].isin(info['pids']))]
        if _fit_and_append(group_df, brand, cat_id=cat_id,
                           cat_name=info['cat_name'], cat_level=info['cat_level']):
            cat_curves_fitted += 1

    curves_df = pd.DataFrame(curves)
    if curves_df.empty:
        log.warning("No curves could be fitted")
        return curves_df

    # Write to database: full refresh of the reference-curve table.
    # NaN -> None conversions are required because psycopg2 cannot adapt NaN.
    with conn.cursor() as cur:
        cur.execute("TRUNCATE brand_lifecycle_curves")
        for _, row in curves_df.iterrows():
            cur.execute("""
                INSERT INTO brand_lifecycle_curves
                    (brand, root_category, cat_id, category_level,
                     amplitude, decay_rate, baseline,
                     r_squared, sample_size, median_first_week_sales,
                     median_preorder_sales, median_preorder_days, computed_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())
            """, (
                row['brand'],
                None if pd.isna(row.get('root_category')) else row.get('root_category'),
                None if pd.isna(row.get('cat_id')) else int(row['cat_id']),
                None if pd.isna(row.get('category_level')) else int(row['category_level']),
                row['amplitude'], row['decay_rate'], row['baseline'],
                row['r_squared'], row['sample_size'],
                row['median_first_week_sales'],
                None if pd.isna(row.get('median_preorder_sales')) else row.get('median_preorder_sales'),
                None if pd.isna(row.get('median_preorder_days')) else row.get('median_preorder_days'),
            ))
    conn.commit()

    brand_only = curves_df[curves_df['cat_id'].isna()].shape[0]
    cat_total = curves_df[curves_df['cat_id'].notna()].shape[0]
    log.info(f"Wrote {len(curves_df)} reference curves ({cat_total} brand+category across all levels, {brand_only} brand-only)")
    return curves_df
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Phase 2: Classify products and generate forecasts
|
||
# ---------------------------------------------------------------------------
|
||
def load_products(conn):
    """
    Load all visible products with their metrics for classification.

    Also loads each product's full category ancestry (all hierarchy levels),
    attached as a 'cat_ids' column: a list of cat_ids ordered deepest-first,
    ready for hierarchical curve lookup.
    """
    sql = """
        SELECT
            pm.pid,
            p.brand,
            pm.current_price,
            pm.current_stock,
            pm.sales_velocity_daily,
            pm.sales_30d,
            pm.date_first_received,
            pm.date_last_sold,
            p.preorder_count,
            COALESCE(p.baskets, 0) AS baskets,
            EXTRACT(DAY FROM NOW() - pm.date_first_received) AS age_days
        FROM product_metrics pm
        JOIN products p ON p.pid = pm.pid
        WHERE p.visible = true
    """
    products = execute_query(conn, sql)

    # Category assignments for all products; deal categories are excluded and
    # rows come back ordered deepest-level-first per pid.
    cat_sql = """
        SELECT pc.pid, ch.cat_id, ch.level AS cat_level
        FROM product_categories pc
        JOIN category_hierarchy ch ON ch.cat_id = pc.cat_id
        WHERE ch.name NOT IN %s
        ORDER BY pc.pid, ch.level DESC
    """
    cat_df = execute_query(conn, cat_sql, [tuple(DEAL_CATEGORIES)])

    # pid -> [cat_id, ...] preserving the deepest-first query order
    pid_cat_ids = {}
    for _, rec in cat_df.iterrows():
        pid_cat_ids.setdefault(int(rec['pid']), []).append(int(rec['cat_id']))

    # Attach category list to each product row as a Python object column
    products['cat_ids'] = products['pid'].apply(lambda p: pid_cat_ids.get(int(p), []))

    log.info(f"Loaded {len(products)} products, "
             f"{sum(1 for c in products['cat_ids'] if len(c) > 0)}/{len(products)} with category data")
    return products
|
||
|
||
|
||
def classify_phase(row):
    """
    Classify a product into its lifecycle phase.

    Args:
        row: dict-like product record (e.g. a DataFrame row from load_products)
             with keys preorder_count, age_days, sales_velocity_daily and
             date_first_received. Missing values may arrive either as None or
             as pandas NaT/NaN when the row was loaded via read_sql.

    Returns:
        One of 'preorder', 'launch', 'decay', 'mature', 'slow_mover', 'dormant'.
    """
    preorder = (row.get('preorder_count') or 0) > 0
    age = row.get('age_days')
    velocity = row.get('sales_velocity_daily')
    first_received = row.get('date_first_received')

    # Bug fix: rows loaded through pandas carry NaT (timestamps) and NaN
    # (numerics) for SQL NULLs, which `is None` checks do not catch. A
    # never-received product previously skipped both the preorder branch and
    # the "unknown lifecycle" branch (NaN comparisons are always False) and
    # was misclassified by the age ladder below. Normalize to None/0 first.
    if first_received is not None and pd.isna(first_received):
        first_received = None
    if age is not None and pd.isna(age):
        age = None
    # NaN is truthy, so `velocity or 0` would keep NaN — handle it explicitly.
    if velocity is None or pd.isna(velocity):
        velocity = 0

    # Pre-order: has preorder_count and either not received or very recently received
    if preorder and (first_received is None or (age is not None and age <= LAUNCH_AGE_DAYS)):
        return 'preorder'

    # No first_received date — can't determine lifecycle; fall back to velocity
    if first_received is None or age is None:
        if velocity > MATURE_VELOCITY_THRESHOLD:
            return 'mature'
        if velocity > 0:
            return 'slow_mover'
        return 'dormant'

    if age <= LAUNCH_AGE_DAYS:
        return 'launch'
    if age <= DECAY_AGE_DAYS:
        # Still in the decay window: any sales at all count as active decay.
        return 'decay' if velocity > 0 else 'dormant'
    # Past the decay window: split by sustained velocity.
    if velocity > MATURE_VELOCITY_THRESHOLD:
        return 'mature'
    if velocity > 0:
        return 'slow_mover'
    return 'dormant'
|
||
|
||
|
||
def get_curve_for_product(product, curves_df):
    """
    Look up the best matching reference curve for a product.

    Walks the product's categories deepest-first looking for a brand×category
    curve, then falls back to the brand-only curve, so e.g.
    "49 and Market × 12x12 Paper Pads" wins over "49 and Market × Paper Crafts"
    when both exist. Curves with R² below MIN_R_SQUARED are ignored.

    Returns:
        (amplitude, decay_rate, baseline, median_first_week, median_preorder,
        median_preorder_days) or None when no reliable curve matches.
    """
    brand = product.get('brand')
    if brand is None or curves_df.empty:
        return None

    # Restrict once to this brand's reliable curves.
    reliable = curves_df[
        (curves_df['brand'] == brand)
        & (curves_df['r_squared'] >= MIN_R_SQUARED)
    ]
    if reliable.empty:
        return None

    def _params(rec):
        """Unpack one curve row into the tuple callers expect."""
        preorder_sales = rec.get('median_preorder_sales')
        preorder_days = rec.get('median_preorder_days')
        return (
            float(rec['amplitude']),
            float(rec['decay_rate']),
            float(rec['baseline']),
            float(rec['median_first_week_sales'] or 1),
            float(preorder_sales) if pd.notna(preorder_sales) else None,
            float(preorder_days) if pd.notna(preorder_days) else None,
        )

    # Try each category, deepest to shallowest.
    for cat_id in (product.get('cat_ids') or []):
        hits = reliable[reliable['cat_id'] == cat_id]
        if not hits.empty:
            return _params(hits.iloc[0])

    # Fall back to the brand-only curve (cat_id is NaN/None).
    fallback = reliable[reliable['cat_id'].isna()]
    if fallback.empty:
        return None
    return _params(fallback.iloc[0])
|
||
|
||
|
||
def forecast_from_curve(curve_params, scale_factor, age_days, horizon_days):
    """
    Generate a daily forecast by evaluating a scaled decay curve.

    The scale factor is applied only to the decay envelope, NOT the baseline,
    which prevents hot products from getting inflated long-tail forecasts.

    Formula: daily_value = (A/7) * exp(-λ * t_weeks) * scale + (C/7),
    clipped at zero.

    Args:
        curve_params: (amplitude, decay_rate, baseline, ...) in weekly units
        scale_factor: multiplier for the decay envelope
        age_days: current product age in days
        horizon_days: how many days to forecast

    Returns:
        numpy array of horizon_days non-negative daily forecast values
    """
    amplitude, decay_rate, baseline = curve_params[:3]

    # Curve parameters are weekly; evaluate one point per future day.
    t_weeks = (age_days + np.arange(horizon_days)) / 7.0
    daily = (amplitude / 7.0) * np.exp(-decay_rate * t_weeks) * scale_factor + baseline / 7.0
    return np.maximum(daily, 0.0)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Batch data loading (eliminates N+1 per-product queries)
|
||
# ---------------------------------------------------------------------------
|
||
def batch_load_product_data(conn, products):
    """
    Batch-load all per-product data needed for forecasting in a few queries
    instead of one query per product.

    Args:
        conn: open psycopg2 connection.
        products: DataFrame from load_products() that already carries a
            'phase' column — only the pids of the relevant phase are queried
            for each dataset.

    Returns dict with keys:
        preorder_sales: {pid: units} — pre-order units (before first received)
        preorder_days:  {pid: days}  — days those pre-orders accumulated over
        launch_sales:   {pid: units} — first 14 days of sales
        decay_velocity: {pid: avg}   — recent 30-day daily average
        mature_history: {pid: DataFrame} — daily sales history for SES
    """
    data = {
        'preorder_sales': {},
        'preorder_days': {},
        'launch_sales': {},
        'decay_velocity': {},
        'mature_history': {},
    }

    # Pre-order sales: orders placed BEFORE first received date.
    # Also compute the number of days pre-orders accumulated over (for
    # daily-rate normalization); GREATEST(..., 1) avoids a zero-day window.
    preorder_pids = products[products['phase'] == 'preorder']['pid'].tolist()
    if preorder_pids:
        sql = """
            SELECT o.pid,
                   COALESCE(SUM(o.quantity), 0) AS preorder_units,
                   GREATEST(EXTRACT(DAY FROM NOW() - MIN(o.date)), 1) AS preorder_days
            FROM orders o
            LEFT JOIN product_metrics pm ON pm.pid = o.pid
            WHERE o.pid = ANY(%s)
              AND o.canceled IS DISTINCT FROM TRUE
              AND (pm.date_first_received IS NULL OR o.date < pm.date_first_received)
            GROUP BY o.pid
        """
        df = execute_query(conn, sql, [preorder_pids])
        for _, row in df.iterrows():
            data['preorder_sales'][int(row['pid'])] = float(row['preorder_units'])
            data['preorder_days'][int(row['pid'])] = float(row['preorder_days'])
        log.info(f"Batch loaded pre-order sales for {len(data['preorder_sales'])}/{len(preorder_pids)} preorder products")

    # Launch sales: first 14 days after first received
    launch_pids = products[products['phase'] == 'launch']['pid'].tolist()
    if launch_pids:
        sql = """
            SELECT dps.pid, COALESCE(SUM(dps.units_sold), 0) AS total_sold
            FROM daily_product_snapshots dps
            JOIN product_metrics pm ON pm.pid = dps.pid
            WHERE dps.pid = ANY(%s)
              AND dps.snapshot_date >= pm.date_first_received::date
              AND dps.snapshot_date < pm.date_first_received::date + INTERVAL '14 days'
            GROUP BY dps.pid
        """
        df = execute_query(conn, sql, [launch_pids])
        for _, row in df.iterrows():
            data['launch_sales'][int(row['pid'])] = float(row['total_sold'])
        log.info(f"Batch loaded launch sales for {len(data['launch_sales'])}/{len(launch_pids)} launch products")

    # Decay recent velocity: average daily sales over last 30 days.
    # NOTE(review): the average is over snapshot rows only — days with no
    # snapshot are not counted as zero-sales days; confirm snapshots are daily.
    decay_pids = products[products['phase'] == 'decay']['pid'].tolist()
    if decay_pids:
        sql = """
            SELECT dps.pid, AVG(COALESCE(dps.units_sold, 0)) AS avg_daily
            FROM daily_product_snapshots dps
            WHERE dps.pid = ANY(%s)
              AND dps.snapshot_date >= CURRENT_DATE - INTERVAL '30 days'
            GROUP BY dps.pid
        """
        df = execute_query(conn, sql, [decay_pids])
        for _, row in df.iterrows():
            data['decay_velocity'][int(row['pid'])] = float(row['avg_daily'])
        log.info(f"Batch loaded decay velocity for {len(data['decay_velocity'])}/{len(decay_pids)} decay products")

    # Mature daily history: full time series for exponential smoothing
    mature_pids = products[products['phase'] == 'mature']['pid'].tolist()
    if mature_pids:
        sql = """
            SELECT dps.pid, dps.snapshot_date, COALESCE(dps.units_sold, 0) AS units_sold
            FROM daily_product_snapshots dps
            WHERE dps.pid = ANY(%s)
              AND dps.snapshot_date >= CURRENT_DATE - INTERVAL '1 day' * %s
            ORDER BY dps.pid, dps.snapshot_date
        """
        df = execute_query(conn, sql, [mature_pids, EXP_SMOOTHING_WINDOW])
        for pid, group in df.groupby('pid'):
            # copy() detaches each per-pid frame from the groupby view
            data['mature_history'][int(pid)] = group.copy()
        log.info(f"Batch loaded history for {len(data['mature_history'])}/{len(mature_pids)} mature products")

    return data
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Per-product scale factor computation
|
||
# ---------------------------------------------------------------------------
|
||
def compute_scale_factor(phase, product, curve_info, batch_data):
    """
    Compute the per-product scale factor for the brand curve.

    The scale factor captures how much more/less this product sells compared
    to the brand average. It's applied to the decay envelope only (not baseline).

    Args:
        phase: lifecycle phase string from classify_phase().
        product: dict-like product record (needs pid, baskets, age_days).
        curve_info: tuple from get_curve_for_product(), or None.
        batch_data: dict from batch_load_product_data().

    Returns:
        float scale in [0.1, 5.0] for preorder products, [0.1, 8.0] otherwise;
        1.0 when no curve or no usable signal is available.
    """
    if curve_info is None:
        return 1.0

    pid = int(product['pid'])
    amplitude, decay_rate, baseline, median_fw, median_preorder, med_preorder_days = curve_info

    if phase == 'preorder':
        preorder_units = batch_data['preorder_sales'].get(pid, 0)
        preorder_days = batch_data['preorder_days'].get(pid, 1)
        baskets = product.get('baskets') or 0

        # Too few days of accumulation → noisy signal, use brand average
        if preorder_days < MIN_PREORDER_DAYS and preorder_units > 0:
            scale = 1.0
            return max(0.1, min(scale, 5.0))

        # Use order units as primary signal; fall back to baskets if no orders.
        # For baskets, assume at least a 14-day accumulation window since
        # basket adds aren't timestamped like orders.
        demand_signal = preorder_units if preorder_units > 0 else baskets
        signal_days = preorder_days if preorder_units > 0 else max(preorder_days, 14)

        # Normalize to daily rate before comparing to brand median daily rate.
        # Use the brand's stored median pre-order window for the denominator
        # (not the current product's signal_days) to avoid systematic bias.
        demand_daily = demand_signal / max(signal_days, 1)
        if median_preorder and median_preorder > 0:
            brand_preorder_window = max(med_preorder_days or signal_days, 1)
            median_preorder_daily = median_preorder / brand_preorder_window
            scale = demand_daily / median_preorder_daily
        elif median_fw > 0 and demand_daily > 0:
            # No brand pre-order stats: compare against first-week daily rate.
            median_fw_daily = median_fw / 7.0
            scale = demand_daily / median_fw_daily
        else:
            scale = 1.0

    elif phase == 'launch':
        actual_sold = batch_data['launch_sales'].get(pid, 0)
        age = max(0, product.get('age_days') or 0)
        if median_fw > 0 and actual_sold > 0:
            # Project the observed partial window to a full first week,
            # then compare against the brand's median first-week sales.
            days_observed = min(age, 14)
            if days_observed > 0:
                projected_first_week = (actual_sold / days_observed) * 7
                scale = projected_first_week / median_fw
            else:
                scale = 1.0
        else:
            scale = 1.0

    elif phase == 'decay':
        actual_velocity = batch_data['decay_velocity'].get(pid, 0)
        age = max(0, product.get('age_days') or 0)
        t_weeks = age / 7.0
        # With baseline fix: value = (A/7)*exp(-λt)*scale + C/7
        # Solve for scale: scale = (actual - C/7) / ((A/7)*exp(-λt))
        decay_part = (amplitude / 7.0) * np.exp(-decay_rate * t_weeks)
        # Use a higher floor for the denominator at high ages to prevent
        # extreme scale factors when the decay envelope is nearly zero
        min_decay = max(0.01, amplitude / 70.0)  # at least 10% of week-1 daily value
        if decay_part > min_decay and actual_velocity > 0:
            scale = (actual_velocity - baseline / 7.0) / decay_part
        elif actual_velocity > 0:
            # Envelope too small to invert reliably — fall back to neutral.
            scale = 1.0
        else:
            scale = 1.0

    else:
        # mature / slow_mover / dormant don't use curve scaling
        scale = 1.0

    # Clamp to avoid extreme values — tighter for preorder since the signal
    # is noisier (pre-orders accumulate differently than post-launch sales)
    max_scale = 5.0 if phase == 'preorder' else 8.0
    return max(0.1, min(scale, max_scale))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Mature product forecast (Holt's double exponential smoothing)
|
||
# ---------------------------------------------------------------------------
|
||
def forecast_mature(product, history_df):
    """
    Forecast a mature/evergreen product from its recent daily sales history.

    Tries Holt's damped linear-trend smoothing first (the trend component
    pulls the forecast back down after a one-off spike instead of persisting
    the inflated level), then plain SES, and finally a flat-velocity forecast
    if both model fits fail or the history is too thin.

    Returns a numpy array of FORECAST_HORIZON_DAYS daily unit forecasts.
    """
    pid = int(product['pid'])
    velocity = product.get('sales_velocity_daily') or 0
    flat = np.full(FORECAST_HORIZON_DAYS, velocity)

    # Guard: too little history to smooth — fall back to flat velocity.
    if history_df is None or history_df.empty or len(history_df) < 7:
        return flat

    # Regularize to a contiguous daily series; missing snapshot days mean
    # zero sales, so gaps resample to 0.
    daily = history_df.copy()
    daily['snapshot_date'] = pd.to_datetime(daily['snapshot_date'])
    daily = daily.set_index('snapshot_date').resample('D').sum().fillna(0)
    series = daily['units_sold'].values.astype(float)

    # Smoothing needs at least two non-zero observations to be meaningful.
    if np.count_nonzero(series) < 2:
        return flat

    # Cap forecasts to prevent runaway output from one-time spikes, while
    # still respecting sustained increases (use the larger of 30d velocity
    # and the observed mean as the baseline for the cap).
    observed_mean = float(np.mean(series))
    ceiling = max(velocity, observed_mean) * MAX_SMOOTHING_MULTIPLIER

    # Try Holt's (damped trend converges to a level over the horizon rather
    # than extrapolating a line), then SES as the simpler fallback.
    attempts = (
        lambda: Holt(series, initialization_method='estimated', damped_trend=True),
        lambda: SimpleExpSmoothing(series, initialization_method='estimated'),
    )
    last_err = None
    for make_model in attempts:
        try:
            fitted = make_model().fit(optimized=True)
            return np.clip(fitted.forecast(FORECAST_HORIZON_DAYS), 0, ceiling)
        except Exception as exc:
            last_err = exc

    log.debug(f"ExpSmoothing failed for pid {pid}: {last_err}")
    return flat
|
||
|
||
|
||
def forecast_dormant():
    """Return an all-zero daily forecast for dormant products."""
    return np.zeros(FORECAST_HORIZON_DAYS, dtype=float)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Accuracy-driven confidence margins
|
||
# ---------------------------------------------------------------------------
|
||
# Fallback confidence-interval half-widths (as a fraction of the point
# forecast) per lifecycle phase, used when no forecast-accuracy history is
# available yet. load_accuracy_margins() overrides these with measured WMAPE.
DEFAULT_MARGINS = {
    'preorder': 0.4,
    'launch': 0.35,
    'decay': 0.3,
    'mature': 0.35,
    'slow_mover': 0.5,
    'dormant': 0.5,
}
# Clamp bounds applied to accuracy-derived margins.
MIN_MARGIN = 0.15  # intervals shouldn't be tighter than ±15%
MAX_MARGIN = 1.0  # intervals shouldn't exceed ±100%
|
||
|
||
|
||
def load_accuracy_margins(conn):
    """
    Derive per-phase confidence margins from the most recent accuracy run.

    Reads by-phase WMAPE rows (ordered newest run first) and keeps, for each
    phase, the most recent value clamped into [MIN_MARGIN, MAX_MARGIN].
    WMAPE is already a ratio (e.g. 1.7 = 170%), so it is used directly as a
    margin. Any phase without data — or any failure — falls back to
    DEFAULT_MARGINS.
    """
    margins = dict(DEFAULT_MARGINS)
    try:
        df = execute_query(conn, """
            SELECT fa.dimension_value AS phase, fa.wmape
            FROM forecast_accuracy fa
            JOIN forecast_runs fr ON fr.id = fa.run_id
            WHERE fa.metric_type = 'by_phase'
              AND fr.status IN ('completed', 'backfill')
              AND fa.wmape IS NOT NULL
            ORDER BY fr.finished_at DESC
        """)
        if df.empty:
            log.info("No accuracy data available, using default confidence margins")
            return margins

        # Rows arrive newest-run-first, so the first occurrence of each phase
        # is its most recent measurement — keep only that one.
        latest = df.drop_duplicates(subset='phase', keep='first')
        for _, row in latest.iterrows():
            clamped = min(max(float(row['wmape']), MIN_MARGIN), MAX_MARGIN)
            margins[row['phase']] = clamped

        log.info(f"Loaded accuracy-based margins: {', '.join(f'{k}={v:.2f}' for k, v in margins.items())}")
    except Exception as e:
        # Best-effort: accuracy tables may not exist yet on a fresh install.
        log.warning(f"Could not load accuracy margins, using defaults: {e}")

    return margins
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main orchestration
|
||
# ---------------------------------------------------------------------------
|
||
# Stream-write chunk size: generate_all_forecasts flushes buffered forecast
# rows to the DB after every N products to bound memory use.
FLUSH_EVERY_PRODUCTS = 5000  # Flush forecast rows to DB every N products
|
||
|
||
def generate_all_forecasts(conn, curves_df, dow_indices, monthly_indices=None,
                           accuracy_margins=None):
    """Classify all products, batch-load data, generate and stream-write forecasts.

    Writes forecast rows to product_forecasts in chunks to avoid accumulating
    millions of rows in memory (37K products × 90 days = 3.3M rows).

    Args:
        conn: open psycopg2 connection.
        curves_df: fitted brand lifecycle curves (from build_reference_curves).
        dow_indices: dict of isoweekday (1-7) -> day-of-week multiplier.
        monthly_indices: dict of month (1-12) -> seasonal multiplier;
            defaults to all-1.0 (no seasonal adjustment).
        accuracy_margins: dict of phase -> base confidence margin; defaults
            to DEFAULT_MARGINS.

    Returns:
        (total_rows_written, products_processed, phase_counts_dict)
    """
    if monthly_indices is None:
        monthly_indices = {m: 1.0 for m in range(1, 13)}
    if accuracy_margins is None:
        accuracy_margins = dict(DEFAULT_MARGINS)
    log.info("Loading products for classification...")
    products = load_products(conn)
    log.info(f"Loaded {len(products)} visible products")

    # Classify each product into a lifecycle phase (row-wise apply).
    products['phase'] = products.apply(classify_phase, axis=1)
    phase_counts = products['phase'].value_counts().to_dict()
    log.info(f"Phase distribution: {phase_counts}")

    # Batch-load per-product data (replaces per-product queries)
    log.info("Batch loading product data...")
    batch_data = batch_load_product_data(conn, products)

    today = date.today()
    forecast_dates = [today + timedelta(days=i) for i in range(FORECAST_HORIZON_DAYS)]

    # Pre-compute DOW and seasonal multipliers for each forecast date so the
    # per-product inner loop does no dict lookups.
    dow_multipliers = [dow_indices.get(d.isoweekday(), 1.0) for d in forecast_dates]
    seasonal_multipliers = [monthly_indices.get(d.month, 1.0) for d in forecast_dates]

    # TRUNCATE before streaming writes — the table is fully rebuilt each run.
    with conn.cursor() as cur:
        cur.execute("TRUNCATE product_forecasts")
    conn.commit()

    buffer = []       # pending rows awaiting flush
    methods = {}      # method name -> product count (for run reporting)
    processed = 0
    errors = 0
    total_rows = 0

    insert_sql = """
        INSERT INTO product_forecasts
            (pid, forecast_date, forecast_units, forecast_revenue,
             lifecycle_phase, forecast_method, confidence_lower,
             confidence_upper)
        VALUES %s
    """

    def flush_buffer():
        # Write and commit all buffered rows, then reset the buffer.
        nonlocal buffer, total_rows
        if not buffer:
            return
        with conn.cursor() as cur:
            psycopg2.extras.execute_values(
                cur, insert_sql, buffer,
                template="(%s, %s, %s, %s, %s, %s, %s, %s)",
                page_size=BATCH_SIZE,
            )
        conn.commit()
        total_rows += len(buffer)
        buffer = []

    for _, product in products.iterrows():
        pid = int(product['pid'])
        phase = product['phase']
        price = float(product['current_price'] or 0)
        age = max(0, product.get('age_days') or 0)

        try:
            curve_info = get_curve_for_product(product, curves_df)

            # Dispatch on lifecycle phase to pick the forecasting method.
            if phase in ('preorder', 'launch'):
                if curve_info:
                    scale = compute_scale_factor(phase, product, curve_info, batch_data)
                    forecasts = forecast_from_curve(curve_info, scale, age, FORECAST_HORIZON_DAYS)
                    method = 'lifecycle_curve'
                else:
                    # No reliable curve — fall back to velocity if available
                    velocity = product.get('sales_velocity_daily') or 0
                    if velocity > 0:
                        forecasts = np.full(FORECAST_HORIZON_DAYS, velocity)
                        method = 'velocity'
                    else:
                        forecasts = forecast_dormant()
                        method = 'zero'

            elif phase == 'decay':
                if curve_info:
                    scale = compute_scale_factor(phase, product, curve_info, batch_data)
                    forecasts = forecast_from_curve(curve_info, scale, age, FORECAST_HORIZON_DAYS)
                    method = 'lifecycle_curve'
                else:
                    velocity = product.get('sales_velocity_daily') or 0
                    forecasts = np.full(FORECAST_HORIZON_DAYS, velocity)
                    method = 'velocity'

            elif phase == 'mature':
                history = batch_data['mature_history'].get(pid)
                forecasts = forecast_mature(product, history)
                method = 'exp_smoothing'

            elif phase == 'slow_mover':
                velocity = product.get('sales_velocity_daily') or 0
                forecasts = np.full(FORECAST_HORIZON_DAYS, velocity)
                method = 'velocity'

            else:  # dormant
                forecasts = forecast_dormant()
                method = 'zero'

            # Confidence interval: use accuracy-calibrated margins per phase
            base_margin = accuracy_margins.get(phase, 0.5)

            for i, d in enumerate(forecast_dates):
                base_units = float(forecasts[i]) if i < len(forecasts) else 0.0
                # Apply day-of-week and seasonal adjustments
                units = base_units * dow_multipliers[i] * seasonal_multipliers[i]
                # Widen confidence interval as horizon grows: day 0 = base, day 89 ≈ +50% wider
                horizon_factor = 1.0 + 0.5 * (i / max(FORECAST_HORIZON_DAYS - 1, 1))
                margin = base_margin * horizon_factor
                buffer.append((
                    pid, d,
                    round(units, 2),
                    round(units * price, 4),
                    phase, method,
                    # Lower bound is floored at 0 (margin can exceed 100%).
                    round(units * max(1 - margin, 0), 2),
                    round(units * (1 + margin), 2),
                ))

            methods[method] = methods.get(method, 0) + 1

        except Exception as e:
            log.warning(f"Error forecasting pid {pid}: {e}")
            errors += 1
            # Write zero forecast so we have complete coverage
            # NOTE(review): if the failure happened after some rows for this
            # pid were already buffered, this appends a second row per date —
            # confirm product_forecasts tolerates that (no unique constraint
            # is visible from here).
            for d in forecast_dates:
                buffer.append((pid, d, 0, 0, phase, 'zero', 0, 0))

        processed += 1
        if processed % FLUSH_EVERY_PRODUCTS == 0:
            flush_buffer()
            log.info(f" Processed {processed}/{len(products)} products ({total_rows} rows written)...")

    # Final flush
    flush_buffer()

    log.info(f"Forecast generation complete. {processed} products, {errors} errors, {total_rows} rows")
    log.info(f"Method distribution: {methods}")

    return total_rows, processed, phase_counts
|
||
|
||
|
||
def archive_forecasts(conn, run_id):
    """
    Copy current product_forecasts into history before they get replaced.
    Only archives forecast rows for dates that have already passed,
    so we can later compare them against actuals.

    Args:
        conn: open psycopg2 connection.
        run_id: id of the run currently starting (unused for attribution —
            archived rows are attributed to the previous completed run,
            since those are the forecasts still sitting in product_forecasts).

    Returns:
        Number of rows archived (0 when there is no previous completed run).
    """
    with conn.cursor() as cur:
        # Ensure history table exists
        cur.execute("""
            CREATE TABLE IF NOT EXISTS product_forecasts_history (
                run_id INT NOT NULL,
                pid BIGINT NOT NULL,
                forecast_date DATE NOT NULL,
                forecast_units NUMERIC(10,2),
                forecast_revenue NUMERIC(14,4),
                lifecycle_phase TEXT,
                forecast_method TEXT,
                confidence_lower NUMERIC(10,2),
                confidence_upper NUMERIC(10,2),
                generated_at TIMESTAMP,
                PRIMARY KEY (run_id, pid, forecast_date)
            )
        """)
        cur.execute("CREATE INDEX IF NOT EXISTS idx_pfh_date ON product_forecasts_history(forecast_date)")
        cur.execute("CREATE INDEX IF NOT EXISTS idx_pfh_pid_date ON product_forecasts_history(pid, forecast_date)")

        # Find the previous completed run (whose forecasts are still in product_forecasts)
        cur.execute("""
            SELECT id FROM forecast_runs
            WHERE status = 'completed'
            ORDER BY finished_at DESC
            LIMIT 1
        """)
        prev_run = cur.fetchone()
        if prev_run is None:
            # First-ever run: nothing to archive, but commit the DDL above.
            log.info("No previous completed run found, skipping archive")
            conn.commit()
            return 0

        prev_run_id = prev_run[0]

        # Archive only past-date forecasts (where actuals now exist).
        # ON CONFLICT DO NOTHING makes re-runs on the same day idempotent.
        cur.execute("""
            INSERT INTO product_forecasts_history
                (run_id, pid, forecast_date, forecast_units, forecast_revenue,
                 lifecycle_phase, forecast_method, confidence_lower, confidence_upper, generated_at)
            SELECT %s, pid, forecast_date, forecast_units, forecast_revenue,
                   lifecycle_phase, forecast_method, confidence_lower, confidence_upper, generated_at
            FROM product_forecasts
            WHERE forecast_date < CURRENT_DATE
            ON CONFLICT (run_id, pid, forecast_date) DO NOTHING
        """, (prev_run_id,))

        archived = cur.rowcount
        conn.commit()

        if archived > 0:
            log.info(f"Archived {archived} historical forecast rows from run {prev_run_id}")
        else:
            log.info("No past-date forecasts to archive")

        # Prune old history (keep 90 days for accuracy analysis)
        cur.execute("DELETE FROM product_forecasts_history WHERE forecast_date < CURRENT_DATE - INTERVAL '90 days'")
        pruned = cur.rowcount
        if pruned > 0:
            log.info(f"Pruned {pruned} old history rows (>90 days)")
        conn.commit()

    return archived
|
||
|
||
|
||
def compute_accuracy(conn, run_id):
    """
    Compute forecast accuracy metrics from archived history vs. actual sales.

    Joins product_forecasts_history with daily_product_snapshots on
    (pid, forecast_date = snapshot_date) to compare forecasted vs. actual units.

    Stores results in forecast_accuracy table, broken down by:
      - overall: single aggregate row
      - by_phase: per lifecycle phase
      - by_lead_time: bucketed by how far ahead the forecast was
      - by_method: per forecast method
      - daily: per forecast_date (for trend charts)

    Args:
        conn: open psycopg2 connection.
        run_id: run the computed metrics are attributed to.
    """
    with conn.cursor() as cur:
        # Ensure accuracy table exists
        cur.execute("""
            CREATE TABLE IF NOT EXISTS forecast_accuracy (
                run_id INT NOT NULL,
                metric_type TEXT NOT NULL,
                dimension_value TEXT NOT NULL,
                sample_size INT,
                total_actual_units NUMERIC(12,2),
                total_forecast_units NUMERIC(12,2),
                mae NUMERIC(10,4),
                wmape NUMERIC(10,4),
                bias NUMERIC(10,4),
                rmse NUMERIC(10,4),
                computed_at TIMESTAMP NOT NULL DEFAULT NOW(),
                PRIMARY KEY (run_id, metric_type, dimension_value)
            )
        """)
        conn.commit()

        # Check if we have any history to analyze
        cur.execute("SELECT COUNT(*) FROM product_forecasts_history")
        history_count = cur.fetchone()[0]
        if history_count == 0:
            log.info("No forecast history available for accuracy computation")
            return

        # For each (pid, forecast_date) pair, keep only the most recent run's
        # forecast row. This prevents double-counting when multiple runs have
        # archived forecasts for the same product×date combination.
        # Rows where both forecast and actual are 0 are excluded so the
        # metrics aren't diluted by trivially-correct zero/zero pairs.
        accuracy_cte = """
            WITH ranked_history AS (
                SELECT
                    pfh.*,
                    fr.started_at,
                    ROW_NUMBER() OVER (
                        PARTITION BY pfh.pid, pfh.forecast_date
                        ORDER BY fr.started_at DESC
                    ) AS rn
                FROM product_forecasts_history pfh
                JOIN forecast_runs fr ON fr.id = pfh.run_id
            ),
            accuracy AS (
                SELECT
                    rh.lifecycle_phase,
                    rh.forecast_method,
                    rh.forecast_date,
                    (rh.forecast_date - rh.started_at::date) AS lead_days,
                    rh.forecast_units,
                    COALESCE(dps.units_sold, 0) AS actual_units,
                    (rh.forecast_units - COALESCE(dps.units_sold, 0)) AS error,
                    ABS(rh.forecast_units - COALESCE(dps.units_sold, 0)) AS abs_error
                FROM ranked_history rh
                LEFT JOIN daily_product_snapshots dps
                    ON dps.pid = rh.pid AND dps.snapshot_date = rh.forecast_date
                WHERE rh.rn = 1
                  AND NOT (rh.forecast_units = 0 AND COALESCE(dps.units_sold, 0) = 0)
            )
        """

        # Compute and insert metrics for each dimension.
        # Each entry's query enumerates the distinct dimension values; the
        # matching filter_clauses entry then filters the CROSS JOIN below
        # down to the rows belonging to each value.
        dimensions = {
            'overall': "SELECT 'all' AS dim",
            'by_phase': "SELECT DISTINCT lifecycle_phase AS dim FROM accuracy",
            'by_lead_time': """
                SELECT DISTINCT
                    CASE
                        WHEN lead_days BETWEEN 0 AND 6 THEN '1-7d'
                        WHEN lead_days BETWEEN 7 AND 13 THEN '8-14d'
                        WHEN lead_days BETWEEN 14 AND 29 THEN '15-30d'
                        WHEN lead_days BETWEEN 30 AND 59 THEN '31-60d'
                        ELSE '61-90d'
                    END AS dim
                FROM accuracy
            """,
            'by_method': "SELECT DISTINCT forecast_method AS dim FROM accuracy",
            'daily': "SELECT DISTINCT forecast_date::text AS dim FROM accuracy",
        }

        filter_clauses = {
            'overall': "lifecycle_phase != 'dormant'",
            'by_phase': "lifecycle_phase = dims.dim",
            # Bucketing CASE must mirror the one in dimensions['by_lead_time'].
            'by_lead_time': """
                CASE
                    WHEN lead_days BETWEEN 0 AND 6 THEN '1-7d'
                    WHEN lead_days BETWEEN 7 AND 13 THEN '8-14d'
                    WHEN lead_days BETWEEN 14 AND 29 THEN '15-30d'
                    WHEN lead_days BETWEEN 30 AND 59 THEN '31-60d'
                    ELSE '61-90d'
                END = dims.dim
            """,
            'by_method': "forecast_method = dims.dim",
            'daily': "forecast_date::text = dims.dim",
        }

        total_inserted = 0

        for metric_type, dim_query in dimensions.items():
            filter_clause = filter_clauses[metric_type]

            # CROSS JOIN dims × accuracy, filtered per-dimension, grouped per
            # dim value. WMAPE is NULL when there were no actual units.
            sql = f"""
                {accuracy_cte},
                dims AS ({dim_query})
                SELECT
                    dims.dim,
                    COUNT(*) AS sample_size,
                    COALESCE(SUM(a.actual_units), 0) AS total_actual,
                    COALESCE(SUM(a.forecast_units), 0) AS total_forecast,
                    AVG(a.abs_error) AS mae,
                    CASE WHEN SUM(a.actual_units) > 0
                         THEN SUM(a.abs_error) / SUM(a.actual_units)
                         ELSE NULL END AS wmape,
                    AVG(a.error) AS bias,
                    SQRT(AVG(POWER(a.error, 2))) AS rmse
                FROM dims
                CROSS JOIN accuracy a
                WHERE {filter_clause}
                GROUP BY dims.dim
            """

            cur.execute(sql)
            rows = cur.fetchall()

            for row in rows:
                dim_val, sample_size, total_actual, total_forecast, mae, wmape, bias, rmse = row
                # Upsert so re-running accuracy for the same run_id refreshes
                # the metrics instead of failing on the primary key.
                cur.execute("""
                    INSERT INTO forecast_accuracy
                        (run_id, metric_type, dimension_value, sample_size,
                         total_actual_units, total_forecast_units, mae, wmape, bias, rmse)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (run_id, metric_type, dimension_value)
                    DO UPDATE SET
                        sample_size = EXCLUDED.sample_size,
                        total_actual_units = EXCLUDED.total_actual_units,
                        total_forecast_units = EXCLUDED.total_forecast_units,
                        mae = EXCLUDED.mae, wmape = EXCLUDED.wmape,
                        bias = EXCLUDED.bias, rmse = EXCLUDED.rmse,
                        computed_at = NOW()
                """, (run_id, metric_type, dim_val, sample_size,
                      float(total_actual), float(total_forecast),
                      float(mae) if mae is not None else None,
                      float(wmape) if wmape is not None else None,
                      float(bias) if bias is not None else None,
                      float(rmse) if rmse is not None else None))
                total_inserted += 1

        conn.commit()

        # Prune old accuracy data (keep 90 days of runs + any in-progress run)
        cur.execute("""
            DELETE FROM forecast_accuracy
            WHERE run_id NOT IN (
                SELECT id FROM forecast_runs
                WHERE finished_at >= NOW() - INTERVAL '90 days'
                   OR finished_at IS NULL
            )
        """)
        pruned = cur.rowcount
        conn.commit()

        log.info(f"Accuracy metrics: {total_inserted} rows computed"
                 + (f", {pruned} old rows pruned" if pruned > 0 else ""))
|
||
|
||
|
||
def backfill_accuracy_data(conn, backfill_days=30):
    """
    Generate retroactive forecast data for the past N days to bootstrap
    accuracy metrics. Uses the current brand curves with per-product scaling
    to approximate what the model would have predicted for each past day,
    then stores results in product_forecasts_history for comparison against
    actual snapshots.

    This is a model backtest (in-sample), not true out-of-sample accuracy,
    but provides much better initial estimates than unscaled brand curves.

    Args:
        conn: open psycopg2 connection.
        backfill_days: number of past days to simulate forecasts for.

    Returns:
        id of the synthetic 'backfill' row created in forecast_runs.
    """
    backfill_start_time = time.time()
    log.info(f"Backfilling {backfill_days} days of accuracy data with per-product scaling...")

    # Load DOW indices
    dow_indices = compute_dow_indices(conn)

    # Load brand curves (already fitted)
    curves_df = execute_query(conn, """
        SELECT brand, root_category, cat_id, category_level,
               amplitude, decay_rate, baseline,
               r_squared, median_first_week_sales, median_preorder_sales,
               median_preorder_days
        FROM brand_lifecycle_curves
    """)

    # Load products
    products = load_products(conn)
    products['phase'] = products.apply(classify_phase, axis=1)

    # Skip dormant — they forecast 0 and are filtered from accuracy anyway
    active = products[products['phase'] != 'dormant'].copy()
    log.info(f"Backfilling for {len(active)} non-dormant products "
             f"(skipping {len(products) - len(active)} dormant)")

    # Batch load product data for per-product scaling
    batch_data = batch_load_product_data(conn, active)

    today = date.today()
    backfill_start = today - timedelta(days=backfill_days)

    # Create a synthetic run entry so the backfilled history rows have a
    # parent run (compute_accuracy joins history to forecast_runs).
    with conn.cursor() as cur:
        cur.execute("""
            INSERT INTO forecast_runs
                (started_at, finished_at, status, products_forecast,
                 phase_counts, error_message)
            VALUES (%s, NOW(), 'backfill', %s, %s, %s)
            RETURNING id
        """, (
            backfill_start,
            len(active),
            json.dumps({'backfill_days': backfill_days}),
            f'Model backtest: {backfill_days} days with per-product scaling',
        ))
        backfill_run_id = cur.fetchone()[0]
        conn.commit()
    log.info(f"Created backfill run {backfill_run_id} "
             f"(simulated start: {backfill_start})")

    # Generate retroactive forecasts
    all_rows = []
    backfill_dates = [backfill_start + timedelta(days=i)
                      for i in range(backfill_days)]

    for _, product in active.iterrows():
        pid = int(product['pid'])
        price = float(product['current_price'] or 0)
        current_age = product.get('age_days')
        velocity = float(product.get('sales_velocity_daily') or 0)
        phase = product['phase']

        curve_info = get_curve_for_product(product, curves_df)

        # Compute per-product scale factor (same logic as main forecast)
        # NOTE(review): curve_info can be None here — confirm that
        # compute_scale_factor tolerates a None curve (its definition is
        # outside this chunk).
        scale = compute_scale_factor(phase, product, curve_info, batch_data)

        for forecast_date in backfill_dates:
            # How many days ago was this date?
            days_ago = (today - forecast_date).days
            # Product's age on that date
            past_age = (current_age - days_ago) if current_age is not None else None

            if past_age is not None and past_age < 0:
                # Product didn't exist yet on this date
                continue

            # Determine what phase the product was likely in
            if past_age is not None:
                if past_age <= LAUNCH_AGE_DAYS:
                    past_phase = 'launch'
                elif past_age <= DECAY_AGE_DAYS:
                    past_phase = 'decay'
                else:
                    past_phase = phase  # use current classification
            else:
                past_phase = phase

            # Compute forecast value for this date
            if past_phase in ('launch', 'decay', 'preorder') and curve_info:
                amplitude, decay_rate, baseline = curve_info[:3]
                age_for_calc = max(0, past_age or 0)
                t_weeks = age_for_calc / 7.0
                # Use corrected formula: scale only the decay envelope, not the baseline
                daily_value = (amplitude / 7.0) * np.exp(-decay_rate * t_weeks) * scale + (baseline / 7.0)
                units = max(0.0, float(daily_value))
                method = 'lifecycle_curve'
            elif past_phase == 'mature' and velocity > 0:
                units = velocity
                method = 'exp_smoothing'
            else:
                units = velocity if velocity > 0 else 0.0
                method = 'velocity' if velocity > 0 else 'zero'

            # Apply DOW multiplier
            dow_mult = dow_indices.get(forecast_date.isoweekday(), 1.0)
            units *= dow_mult

            if units == 0 and method == 'zero':
                continue  # skip zero-zero rows

            revenue = units * price
            # Fixed margins for the backtest (no accuracy data exists yet).
            margin = 0.3 if method == 'lifecycle_curve' else 0.4

            all_rows.append((
                backfill_run_id, pid, forecast_date,
                round(float(units), 2),
                round(float(revenue), 4),
                past_phase, method,
                round(float(units * (1 - margin)), 2),
                round(float(units * (1 + margin)), 2),
                backfill_start,  # generated_at
            ))

    log.info(f"Generated {len(all_rows)} backfill forecast rows")

    # Write to history table
    if all_rows:
        with conn.cursor() as cur:
            sql = """
                INSERT INTO product_forecasts_history
                    (run_id, pid, forecast_date, forecast_units, forecast_revenue,
                     lifecycle_phase, forecast_method, confidence_lower,
                     confidence_upper, generated_at)
                VALUES %s
                ON CONFLICT (run_id, pid, forecast_date) DO NOTHING
            """
            psycopg2.extras.execute_values(
                cur, sql, all_rows,
                template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
                page_size=BATCH_SIZE,
            )
            conn.commit()
        log.info(f"Wrote {len(all_rows)} rows to product_forecasts_history")

    # Now compute accuracy on the backfilled data
    compute_accuracy(conn, backfill_run_id)

    # Mark the backfill run as completed
    backfill_duration = time.time() - backfill_start_time
    with conn.cursor() as cur:
        cur.execute("""
            UPDATE forecast_runs
            SET finished_at = NOW(), status = 'backfill',
                duration_seconds = %s
            WHERE id = %s
        """, (round(backfill_duration, 2), backfill_run_id))
        conn.commit()

    log.info(f"Backfill complete in {backfill_duration:.1f}s")
    return backfill_run_id
|
||
|
||
|
||
def main():
    """Entry point: run either a backfill (--backfill [N]) or a full forecast run.

    A full run records a forecast_runs row, builds curves, archives and
    scores previous forecasts, generates new ones, then marks the run
    completed — or failed with the error message on any exception.
    """
    start_time = time.time()

    conn = get_connection()

    # Clean up any stale "running" entries from prior crashes
    cleanup_stale_runs(conn)

    # Check for --backfill flag
    if '--backfill' in sys.argv:
        idx = sys.argv.index('--backfill')
        # Optional day count after the flag; defaults to 30 when absent.
        # NOTE(review): a non-numeric following argument raises ValueError here.
        days = int(sys.argv[idx + 1]) if idx + 1 < len(sys.argv) else 30
        log.info("=" * 60)
        log.info(f"Backfill mode: {days} days")
        log.info("=" * 60)
        try:
            backfill_accuracy_data(conn, days)
        finally:
            conn.close()
        return

    log.info("=" * 60)
    log.info("Forecast Engine starting")
    log.info("=" * 60)

    run_id = None

    try:
        # Record run start
        with conn.cursor() as cur:
            cur.execute(
                "INSERT INTO forecast_runs (started_at, status) VALUES (NOW(), 'running') RETURNING id"
            )
            run_id = cur.fetchone()[0]
        conn.commit()

        # Phase 0: Compute day-of-week and monthly seasonal indices
        dow_indices = compute_dow_indices(conn)
        monthly_indices = compute_monthly_seasonal_indices(conn)

        # Phase 1: Build reference curves
        curves_df = build_reference_curves(conn)

        # Phase 2: Archive historical forecasts (before TRUNCATE in generation)
        archive_forecasts(conn, run_id)

        # Phase 3: Compute accuracy from archived history vs actuals
        compute_accuracy(conn, run_id)

        # Phase 3b: Load accuracy-calibrated confidence margins
        accuracy_margins = load_accuracy_margins(conn)

        # Phase 4: Generate and stream-write forecasts (TRUNCATE + chunked INSERT)
        total_rows, products_forecast, phase_counts = generate_all_forecasts(
            conn, curves_df, dow_indices, monthly_indices, accuracy_margins
        )

        duration = time.time() - start_time

        # Record run completion (include DOW indices in metadata)
        with conn.cursor() as cur:
            cur.execute("""
                UPDATE forecast_runs
                SET finished_at = NOW(), status = 'completed',
                    products_forecast = %s, phase_counts = %s,
                    curve_count = %s, duration_seconds = %s
                WHERE id = %s
            """, (
                products_forecast,
                json.dumps({
                    **phase_counts,
                    '_dow_indices': {str(k): v for k, v in dow_indices.items()},
                    '_seasonal_indices': {str(k): v for k, v in monthly_indices.items()},
                }),
                len(curves_df) if not curves_df.empty else 0,
                round(duration, 2),
                run_id,
            ))
        conn.commit()

        log.info("=" * 60)
        log.info(f"Forecast complete in {duration:.1f}s")
        log.info(f" Products: {products_forecast}")
        log.info(f" Curves: {len(curves_df) if not curves_df.empty else 0}")
        log.info(f" Phases: {phase_counts}")
        log.info(f" Rows written: {total_rows}")
        log.info("=" * 60)

    except Exception as e:
        duration = time.time() - start_time
        log.error(f"Forecast engine failed: {e}", exc_info=True)

        # Best-effort: mark the run failed; swallow secondary DB errors so
        # the original failure's exit code is preserved.
        if run_id:
            try:
                with conn.cursor() as cur:
                    cur.execute("""
                        UPDATE forecast_runs
                        SET finished_at = NOW(), status = 'failed',
                            error_message = %s, duration_seconds = %s
                        WHERE id = %s
                    """, (str(e), round(duration, 2), run_id))
                conn.commit()
            except Exception:
                pass

        sys.exit(1)
    finally:
        conn.close()
|
||
|
||
|
||
# Script entry point.
if __name__ == '__main__':
    main()
|