Files
inventory/inventory-server/scripts/forecast/forecast_engine.py

1620 lines
63 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Lifecycle-Aware Forecast Engine
Generates 90-day per-product daily sales forecasts using analogous lifecycle
curves learned from historical brand/category launch patterns.
Usage:
python forecast_engine.py
python forecast_engine.py --backfill 30
Environment variables (from .env):
DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_PORT (default 5432)
"""
import os
import sys
import json
import time
import logging
from datetime import datetime, date, timedelta
import numpy as np
import pandas as pd
import psycopg2
import psycopg2.extras
import psycopg2.extensions
from scipy.optimize import curve_fit
from statsmodels.tsa.holtwinters import SimpleExpSmoothing, Holt
# Register numpy scalar adapters so psycopg2 can serialize them to SQL.
# Each numpy value is converted to the matching Python builtin and handed to
# psycopg2's own adapter for that builtin instead of AsIs: AsIs(float('nan'))
# renders as a bare `nan` (invalid SQL), whereas the builtin float adapter
# emits 'NaN'::float / 'Infinity'::float. numpy bools are covered as well.
_NUMPY_SQL_CASTS = (
    (np.float64, float),
    (np.float32, float),
    (np.int64, int),
    (np.int32, int),
    (np.bool_, bool),
)
for _np_type, _builtin in _NUMPY_SQL_CASTS:
    # Bind the cast as a default argument so each lambda keeps its own type.
    psycopg2.extensions.register_adapter(
        _np_type,
        lambda value, _cast=_builtin: psycopg2.extensions.adapt(_cast(value)),
    )
del _np_type, _builtin
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Forecast output
FORECAST_HORIZON_DAYS = 90           # days forecast per product
BATCH_SIZE = 1_000                   # rows per INSERT batch

# Reference-curve construction
CURVE_HISTORY_DAYS = 730             # launches from the last ~2 years feed the curves
CURVE_WINDOW_WEEKS = 13              # weeks of post-launch sales tracked (91 days)
MIN_PRODUCTS_FOR_CURVE = 5           # launches needed for a brand-only curve
MIN_PRODUCTS_FOR_BRAND_CAT = 10      # launches needed for a brand x category curve
MIN_R_SQUARED = 0.1                  # fits below this R² are treated as unreliable

# Lifecycle phase thresholds
LAUNCH_AGE_DAYS = 14                 # days since first received still counted as "launch"
DECAY_AGE_DAYS = 60                  # age up to which a selling product is in "decay"
MATURE_AGE_DAYS = 60                 # days since first_received to be considered mature
MATURE_VELOCITY_THRESHOLD = 0.1      # units/day separating "mature" from "slow_mover"

# Mature-product smoothing
EXP_SMOOTHING_WINDOW = 60            # days of daily history fed to the smoothing models
MAX_SMOOTHING_MULTIPLIER = 10        # cap smoothed forecasts at N x observed velocity

# Seasonality
DOW_LOOKBACK_DAYS = 90               # days of orders behind the day-of-week indices
SEASONAL_LOOKBACK_DAYS = 365         # days of orders behind the monthly indices

# Pre-orders
MIN_PREORDER_DAYS = 3                # shorter pre-order windows are too noisy to scale from
# Module logger plus a root configuration: INFO level, "HH:MM:SS [LEVEL] msg".
log = logging.getLogger('forecast')
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
)
# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------
def get_connection():
    """
    Open a new PostgreSQL connection configured from environment variables.

    Reads DB_HOST (default 'localhost'), DB_USER ('inventory_user'),
    DB_PASSWORD (''), DB_NAME ('inventory_db') and DB_PORT (5432, parsed
    with int()).
    """
    env = os.environ
    settings = {
        'host': env.get('DB_HOST', 'localhost'),
        'user': env.get('DB_USER', 'inventory_user'),
        'password': env.get('DB_PASSWORD', ''),
        'dbname': env.get('DB_NAME', 'inventory_db'),
        'port': int(env.get('DB_PORT', 5432)),
    }
    return psycopg2.connect(**settings)
def execute_query(conn, sql, params=None):
    """
    Run a read query and return the result set as a pandas DataFrame.

    Args:
        conn: DB-API connection passed straight to pd.read_sql_query.
        sql: query text, using the driver's placeholder style.
        params: optional query parameters.
    """
    import warnings

    # pandas warns whenever it is handed a raw DB-API connection that is not
    # SQLAlchemy/sqlite3; that warning is expected here, so mute just that one.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*pandas only supports SQLAlchemy.*")
        frame = pd.read_sql_query(sql, conn, params=params)
    return frame
def cleanup_stale_runs(conn):
    """
    Fail forecast runs left in 'running' state by an earlier crash.

    Every forecast_runs row still 'running' more than an hour after
    started_at is set to status 'failed', finished_at = NOW(), with an
    explanatory error_message. The update is committed immediately and the
    number of affected rows is logged when non-zero.
    """
    stale_update = """
        UPDATE forecast_runs
        SET status = 'failed', finished_at = NOW(),
            error_message = 'Stale run cleaned up on engine restart'
        WHERE status = 'running'
          AND started_at < NOW() - INTERVAL '1 hour'
    """
    with conn.cursor() as cur:
        cur.execute(stale_update)
        n_cleaned = cur.rowcount
    conn.commit()

    if n_cleaned > 0:
        log.info(f"Cleaned up {n_cleaned} stale 'running' forecast run(s)")
# ---------------------------------------------------------------------------
# Decay curve model: sales(t) = A * exp(-λt) + C
# ---------------------------------------------------------------------------
def decay_model(t, amplitude, decay_rate, baseline):
    """
    Evaluate amplitude * exp(-decay_rate * t) + baseline.

    Works element-wise when t is a numpy array (as curve_fit passes it).
    """
    envelope = np.exp(-decay_rate * t)
    return amplitude * envelope + baseline
def fit_decay_curve(weekly_medians):
    """
    Fit the decay model to median weekly sales data.

    Args:
        weekly_medians: array-like of median sales per week (index = week number)

    Returns:
        (amplitude, decay_rate, baseline, r_squared) as floats, or None when the
        series has fewer than 3 points, contains non-finite values, has no
        positive week, or curve_fit fails.
    """
    y = np.array(weekly_medians, dtype=float)

    # A 3-parameter fit needs at least 3 points. The parameter bounds below are
    # derived from the peak, so a series with no positive week has nothing to
    # fit (and would produce lower > upper bounds).
    if len(y) < 3 or not np.all(np.isfinite(y)):
        return None
    a0 = float(np.max(y))
    if a0 <= 0:
        return None

    weeks = np.arange(len(y), dtype=float)

    # Initial guesses: amplitude = peak, baseline = floor of the second half,
    # moderate decay. The baseline guess is clamped into its [0, a0] bounds;
    # a negative late week would otherwise make p0 infeasible and curve_fit
    # would raise instead of fitting the curve.
    c0 = min(max(float(np.min(y[len(y) // 2:])), 0.0), a0)
    lam0 = 0.3

    try:
        popt, _ = curve_fit(
            decay_model, weeks, y,
            p0=[a0, lam0, c0],
            bounds=([0, 0.01, 0], [a0 * 5, 5.0, a0]),
            maxfev=5000,
        )
    except (RuntimeError, ValueError) as e:
        log.debug(f"Curve fit failed: {e}")
        return None

    amplitude, decay_rate, baseline = popt

    # R-squared of the fit; a flat series (zero variance) scores 0.
    y_pred = decay_model(weeks, *popt)
    ss_res = float(np.sum((y - y_pred) ** 2))
    ss_tot = float(np.sum((y - np.mean(y)) ** 2))
    r_sq = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0
    return float(amplitude), float(decay_rate), float(baseline), float(r_sq)
# ---------------------------------------------------------------------------
# Day-of-week indices
# ---------------------------------------------------------------------------
def compute_dow_indices(conn):
    """
    Compute day-of-week revenue indices from recent order history.

    Returns a dict mapping ISO weekday (1=Mon ... 7=Sun) to a multiplier
    normalized so they sum to 7.0 (average = 1.0). This means applying them
    preserves the weekly total while reshaping the daily distribution.

    Each weekday is scored by its average revenue per calendar day in the
    window (days without orders count as zero), not by its raw total. The
    DOW_LOOKBACK_DAYS window need not be a whole number of weeks (e.g. 90 days
    = 12 weeks + 6 days), and raw totals would mark down whichever weekday
    occurs one time fewer. Today is excluded because its orders are still
    accumulating.

    Falls back to a flat distribution when any weekday has no orders at all.
    """
    sql = """
        WITH calendar AS (
            SELECT gs::date AS day
            FROM generate_series(
                CURRENT_DATE - %s::int,
                CURRENT_DATE - 1,
                INTERVAL '1 day'
            ) AS gs
        ),
        daily AS (
            SELECT o.date::date AS day,
                   SUM(o.price * o.quantity) AS revenue
            FROM orders o
            WHERE o.canceled IS DISTINCT FROM TRUE
              AND o.date >= CURRENT_DATE - %s::int
              AND o.date < CURRENT_DATE
            GROUP BY 1
        )
        SELECT
            EXTRACT(ISODOW FROM c.day)::int AS dow,
            SUM(COALESCE(dl.revenue, 0)) / COUNT(*) AS revenue
        FROM calendar c
        LEFT JOIN daily dl ON dl.day = c.day
        GROUP BY 1
        HAVING COUNT(dl.day) > 0
        ORDER BY 1
    """
    df = execute_query(conn, sql, [DOW_LOOKBACK_DAYS, DOW_LOOKBACK_DAYS])
    if df.empty or len(df) < 7:
        log.warning("Insufficient order data for DOW indices, using flat distribution")
        return {d: 1.0 for d in range(1, 8)}

    avg = float(df['revenue'].sum()) / 7.0
    indices = {
        int(dow): round(float(revenue) / avg, 4) if avg > 0 else 1.0
        for dow, revenue in zip(df['dow'], df['revenue'])
    }
    # Defensive: once the 7-row check passes every weekday is present.
    for d in range(1, 8):
        indices.setdefault(d, 1.0)

    log.info(f"DOW indices: Mon={indices[1]:.3f} Tue={indices[2]:.3f} Wed={indices[3]:.3f} "
             f"Thu={indices[4]:.3f} Fri={indices[5]:.3f} Sat={indices[6]:.3f} Sun={indices[7]:.3f}")
    return indices
# ---------------------------------------------------------------------------
# Monthly seasonal indices
# ---------------------------------------------------------------------------
def compute_monthly_seasonal_indices(conn):
    """
    Compute monthly seasonal indices from recent order revenue.

    Returns a dict mapping month number (1-12) to a multiplier normalized
    so the months with data average 1.0. Months with above-average revenue
    get >1, below get <1; months without data get 1.0.

    The indices are applied to individual forecast days, so each month is
    scored by its average revenue per calendar day in the window (days
    without orders count as zero) rather than its total. Totals would bake
    month length into the index (a 28-day February would score ~0.92 purely
    for being short). Today is excluded because its orders are still
    accumulating.
    """
    sql = """
        WITH calendar AS (
            SELECT gs::date AS day
            FROM generate_series(
                CURRENT_DATE - %s::int,
                CURRENT_DATE - 1,
                INTERVAL '1 day'
            ) AS gs
        ),
        daily AS (
            SELECT o.date::date AS day,
                   SUM(o.price * o.quantity) AS revenue
            FROM orders o
            WHERE o.canceled IS DISTINCT FROM TRUE
              AND o.date >= CURRENT_DATE - %s::int
              AND o.date < CURRENT_DATE
            GROUP BY 1
        )
        SELECT
            EXTRACT(MONTH FROM c.day)::int AS month,
            SUM(COALESCE(dl.revenue, 0)) / COUNT(*) AS revenue
        FROM calendar c
        LEFT JOIN daily dl ON dl.day = c.day
        GROUP BY 1
        HAVING COUNT(dl.day) > 0
        ORDER BY 1
    """
    df = execute_query(conn, sql, [SEASONAL_LOOKBACK_DAYS, SEASONAL_LOOKBACK_DAYS])
    if df.empty or len(df) < 6:
        log.warning("Insufficient data for seasonal indices, using flat distribution")
        return {m: 1.0 for m in range(1, 13)}

    avg = float(df['revenue'].sum()) / len(df)
    indices = {
        int(month): round(float(revenue) / avg, 4) if avg > 0 else 1.0
        for month, revenue in zip(df['month'], df['revenue'])
    }
    # Months with no orders in the window stay neutral.
    for m in range(1, 13):
        indices.setdefault(m, 1.0)

    present = [f"{m}={indices[m]:.3f}" for m in range(1, 13)]
    log.info(f"Monthly seasonal indices: {', '.join(present)}")
    return indices
# ---------------------------------------------------------------------------
# Phase 1: Build brand-category reference curves
# ---------------------------------------------------------------------------
# Deal/promo category names. Both the curve builder and the product loader
# exclude these (NOT IN) when collecting a product's category assignments.
DEAL_CATEGORIES = frozenset({
    'Deals',
    'Black Friday',
    'Week 1',
    'Week 2',
    'Week 3',
    '28 Off',
    '5 Dollar Deals',
    '10 Dollar Deals',
    'Fall Sale',
})
def build_reference_curves(conn):
    """
    Build decay curves for each brand (and brand x category at every hierarchy level).

    For category curves, we load each product's full set of category assignments
    (across all hierarchy levels), then fit brand×cat_id curves wherever we have
    enough products. This gives granular curves like "49 and Market × 12x12 Paper Pads"
    alongside coarser fallbacks like "49 and Market × Paper".

    When at least one curve is fitted, brand_lifecycle_curves is truncated and
    rewritten, then committed. When there is no launch data or nothing can be
    fitted, the table is left untouched.

    Returns:
        DataFrame of curves written to brand_lifecycle_curves (an empty
        DataFrame when nothing was written).
    """
    log.info("Building reference curves from historical launches...")

    # Weekly sales per launched product, aligned by weeks since first receipt,
    # for launches within CURVE_HISTORY_DAYS that are at least 14 days old.
    # (no category join here — we attach categories separately)
    sales_sql = """
        WITH recent_launches AS (
            SELECT pm.pid, p.brand
            FROM product_metrics pm
            JOIN products p ON p.pid = pm.pid
            WHERE p.visible = true
              AND p.brand IS NOT NULL
              AND pm.date_first_received >= NOW() - INTERVAL '1 day' * %s
              AND pm.date_first_received < NOW() - INTERVAL '14 days'
        ),
        daily_sales AS (
            SELECT
                rl.pid, rl.brand,
                dps.snapshot_date,
                COALESCE(dps.units_sold, 0) AS units_sold,
                (dps.snapshot_date - pm.date_first_received::date) AS day_offset
            FROM recent_launches rl
            JOIN product_metrics pm ON pm.pid = rl.pid
            JOIN daily_product_snapshots dps ON dps.pid = rl.pid
            WHERE dps.snapshot_date >= pm.date_first_received::date
              AND dps.snapshot_date < pm.date_first_received::date + INTERVAL '1 week' * %s
        )
        SELECT brand, pid,
               FLOOR(day_offset / 7)::int AS week_num,
               SUM(units_sold) AS weekly_sales
        FROM daily_sales
        WHERE day_offset >= 0
        GROUP BY brand, pid, week_num
        ORDER BY brand, pid, week_num
    """
    df = execute_query(conn, sales_sql, [CURVE_HISTORY_DAYS, CURVE_WINDOW_WEEKS])
    if df.empty:
        log.warning("No launch data found for reference curves")
        return pd.DataFrame()
    log.info(f"Loaded {len(df)} weekly sales records from {df['pid'].nunique()} products across {df['brand'].nunique()} brands")

    # Load all category assignments for these products (every hierarchy level),
    # deepest level first, excluding deal/promo categories.
    launch_pids = df['pid'].unique().tolist()
    cat_sql = """
        SELECT pc.pid, ch.cat_id, ch.name AS cat_name, ch.level AS cat_level
        FROM product_categories pc
        JOIN category_hierarchy ch ON ch.cat_id = pc.cat_id
        WHERE pc.pid = ANY(%s)
          AND ch.name NOT IN %s
        ORDER BY pc.pid, ch.level DESC
    """
    cat_df = execute_query(conn, cat_sql, [launch_pids, tuple(DEAL_CATEGORIES)])

    # pid -> [(cat_id, cat_name, cat_level), ...]
    pid_cats = {}
    for pid, cat_id, cat_name, cat_level in zip(
        cat_df['pid'], cat_df['cat_id'], cat_df['cat_name'], cat_df['cat_level']
    ):
        pid_cats.setdefault(int(pid), []).append((int(cat_id), cat_name, int(cat_level)))

    # Also get pre-order stats per brand (median pre-order sales AND accumulation window).
    # Uses de-facto preorders: any product that had orders before date_first_received,
    # regardless of the preorder_count flag. This gives us 6000+ completed cycles vs ~19
    # from the explicit flag alone.
    preorder_sql = """
        WITH preorder_stats AS (
            SELECT p.pid, p.brand,
                   COALESCE((SELECT SUM(o.quantity) FROM orders o
                             WHERE o.pid = p.pid AND o.canceled IS DISTINCT FROM TRUE
                               AND o.date < pm.date_first_received), 0) AS preorder_units,
                   GREATEST(EXTRACT(DAY FROM pm.date_first_received - MIN(o.date)), 1) AS preorder_days
            FROM products p
            JOIN product_metrics pm ON pm.pid = p.pid
            LEFT JOIN orders o ON o.pid = p.pid AND o.canceled IS DISTINCT FROM TRUE
                AND o.date < pm.date_first_received
            WHERE p.visible = true AND p.brand IS NOT NULL
              AND pm.date_first_received IS NOT NULL
              AND pm.date_first_received >= NOW() - INTERVAL '1 day' * %s
            GROUP BY p.pid, p.brand, pm.date_first_received
        )
        SELECT brand,
               PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY preorder_units) AS median_preorder_sales,
               PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY preorder_days) AS median_preorder_days
        FROM preorder_stats
        WHERE preorder_units > 0
        GROUP BY brand
        HAVING COUNT(*) >= 3
    """
    preorder_df = execute_query(conn, preorder_sql, [CURVE_HISTORY_DAYS])
    preorder_map = dict(zip(preorder_df['brand'], preorder_df['median_preorder_sales'])) if not preorder_df.empty else {}
    preorder_days_map = dict(zip(preorder_df['brand'], preorder_df['median_preorder_days'])) if not preorder_df.empty else {}

    curves = []

    def _fit_and_append(group_df, brand, cat_id=None, cat_name=None, cat_level=None):
        """Helper: fit a decay curve for a group and append to curves list."""
        product_count = group_df['pid'].nunique()
        min_products = MIN_PRODUCTS_FOR_CURVE if cat_id is None else MIN_PRODUCTS_FOR_BRAND_CAT
        if product_count < min_products:
            return False

        # Median across products of each week's sales; weeks nobody sold in become 0.
        weekly = group_df.groupby('week_num')['weekly_sales'].median()
        if len(weekly) < 4:
            return False
        full_weeks = weekly.reindex(range(CURVE_WINDOW_WEEKS), fill_value=0.0)
        weekly_arr = full_weeks.values[:CURVE_WINDOW_WEEKS]

        result = fit_decay_curve(weekly_arr)
        if result is None:
            return False
        amplitude, decay_rate, baseline, r_sq = result

        # Quality gate: only store curves above the reliability threshold
        if r_sq < MIN_R_SQUARED:
            return False

        first_week = group_df[group_df['week_num'] == 0].groupby('pid')['weekly_sales'].sum()
        median_fw = float(first_week.median()) if len(first_week) > 0 else 0.0

        curves.append({
            'brand': brand,
            'root_category': cat_name,  # kept for readability; cat_id is the real key
            'cat_id': cat_id,
            'category_level': cat_level,
            'amplitude': amplitude,
            'decay_rate': decay_rate,
            'baseline': baseline,
            'r_squared': r_sq,
            'sample_size': product_count,
            'median_first_week_sales': median_fw,
            'median_preorder_sales': preorder_map.get(brand),
            'median_preorder_days': preorder_days_map.get(brand),
        })
        return True

    # 1. Fit brand-level curves (aggregate across all categories)
    for brand, brand_df in df.groupby('brand'):
        _fit_and_append(brand_df, brand)

    # 2. Fit brand × category curves at every hierarchy level.
    # Each launch pid has a single brand (df rows are per brand, pid), so brand
    # and per-pid rows are indexed once up front instead of re-scanning df for
    # every pid and every (brand, cat_id) pair (previously O(groups × rows)).
    pid_brand = df.drop_duplicates('pid').set_index('pid')['brand'].to_dict()
    rows_by_pid = {pid: rows for pid, rows in df.groupby('pid', sort=False)}

    # (brand, cat_id) -> {'cat_name', 'cat_level', 'pids'}
    brand_cat_pids = {}
    for pid, cats in pid_cats.items():
        brand = pid_brand.get(pid)
        if brand is None:
            continue
        for cat_id, cat_name, cat_level in cats:
            info = brand_cat_pids.setdefault(
                (brand, cat_id),
                {'cat_name': cat_name, 'cat_level': cat_level, 'pids': set()},
            )
            info['pids'].add(pid)

    for (brand, cat_id), info in brand_cat_pids.items():
        frames = [rows_by_pid[p] for p in info['pids'] if p in rows_by_pid]
        if not frames:
            continue
        _fit_and_append(pd.concat(frames), brand, cat_id=cat_id,
                        cat_name=info['cat_name'], cat_level=info['cat_level'])

    curves_df = pd.DataFrame(curves)
    if curves_df.empty:
        log.warning("No curves could be fitted")
        return curves_df

    # Replace the table contents in one transaction.
    with conn.cursor() as cur:
        cur.execute("TRUNCATE brand_lifecycle_curves")
        for _, row in curves_df.iterrows():
            cur.execute("""
                INSERT INTO brand_lifecycle_curves
                    (brand, root_category, cat_id, category_level,
                     amplitude, decay_rate, baseline,
                     r_squared, sample_size, median_first_week_sales,
                     median_preorder_sales, median_preorder_days, computed_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())
            """, (
                row['brand'],
                None if pd.isna(row.get('root_category')) else row.get('root_category'),
                None if pd.isna(row.get('cat_id')) else int(row['cat_id']),
                None if pd.isna(row.get('category_level')) else int(row['category_level']),
                row['amplitude'], row['decay_rate'], row['baseline'],
                row['r_squared'], row['sample_size'],
                row['median_first_week_sales'],
                None if pd.isna(row.get('median_preorder_sales')) else row.get('median_preorder_sales'),
                None if pd.isna(row.get('median_preorder_days')) else row.get('median_preorder_days'),
            ))
    conn.commit()

    brand_only = curves_df[curves_df['cat_id'].isna()].shape[0]
    cat_total = curves_df[curves_df['cat_id'].notna()].shape[0]
    log.info(f"Wrote {len(curves_df)} reference curves ({cat_total} brand+category across all levels, {brand_only} brand-only)")
    return curves_df
# ---------------------------------------------------------------------------
# Phase 2: Classify products and generate forecasts
# ---------------------------------------------------------------------------
def load_products(conn):
    """
    Load every visible product with the metrics used for phase classification.

    The returned DataFrame also carries a ``cat_ids`` column: for each product,
    a Python list of its category ids across all hierarchy levels, deepest
    level first, with categories named in DEAL_CATEGORIES excluded. Products
    without categories get an empty list.
    """
    products_sql = """
        SELECT
            pm.pid,
            p.brand,
            pm.current_price,
            pm.current_stock,
            pm.sales_velocity_daily,
            pm.sales_30d,
            pm.date_first_received,
            pm.date_last_sold,
            p.preorder_count,
            COALESCE(p.baskets, 0) AS baskets,
            EXTRACT(DAY FROM NOW() - pm.date_first_received) AS age_days
        FROM product_metrics pm
        JOIN products p ON p.pid = pm.pid
        WHERE p.visible = true
    """
    products = execute_query(conn, products_sql)

    categories_sql = """
        SELECT pc.pid, ch.cat_id, ch.level AS cat_level
        FROM product_categories pc
        JOIN category_hierarchy ch ON ch.cat_id = pc.cat_id
        WHERE ch.name NOT IN %s
        ORDER BY pc.pid, ch.level DESC
    """
    categories = execute_query(conn, categories_sql, [tuple(DEAL_CATEGORIES)])

    # pid -> [cat_id, ...]; the ORDER BY already yields deepest-first per pid.
    ancestry = {}
    for pid_value, cat_value in zip(categories['pid'], categories['cat_id']):
        ancestry.setdefault(int(pid_value), []).append(int(cat_value))

    # Object column holding one list per product row.
    products['cat_ids'] = products['pid'].apply(lambda pid_value: ancestry.get(int(pid_value), []))

    with_categories = sum(1 for cats in products['cat_ids'] if len(cats) > 0)
    log.info(f"Loaded {len(products)} products, "
             f"{with_categories}/{len(products)} with category data")
    return products
def classify_phase(row):
    """
    Classify a product into its lifecycle phase.

    Args:
        row: product record (a row of load_products' DataFrame, or any mapping
            with ``.get``) providing preorder_count, age_days,
            sales_velocity_daily and date_first_received.

    Returns:
        'preorder', 'launch', 'decay', 'mature', 'slow_mover' or 'dormant'.

    SQL NULLs reach this function as NaN/NaT when the row comes from a
    DataFrame, not as None, so missing values are detected with pd.isna.
    Checking only ``is None`` let never-received pre-order products (NaT
    date_first_received, NaN age) skip the 'preorder' branch and be
    classified by velocity instead.
    """
    def _missing(value):
        return value is None or bool(pd.isna(value))

    def _by_velocity(rate):
        # Used both when lifecycle position is unknown and past the decay window.
        if rate > MATURE_VELOCITY_THRESHOLD:
            return 'mature'
        if rate > 0:
            return 'slow_mover'
        return 'dormant'

    preorder_count = row.get('preorder_count')
    is_preorder = not _missing(preorder_count) and preorder_count > 0

    age = row.get('age_days')
    if _missing(age):
        age = None

    velocity = row.get('sales_velocity_daily')
    if _missing(velocity):
        velocity = 0

    received = not _missing(row.get('date_first_received'))

    # Pre-order: has preorder_count and either not received or very recently received
    if is_preorder and (not received or (age is not None and age <= LAUNCH_AGE_DAYS)):
        return 'preorder'

    # No first_received date — can't determine lifecycle, so go by velocity
    if not received or age is None:
        return _by_velocity(velocity)

    if age <= LAUNCH_AGE_DAYS:
        return 'launch'
    if age <= DECAY_AGE_DAYS:
        return 'decay' if velocity > 0 else 'dormant'
    return _by_velocity(velocity)
def get_curve_for_product(product, curves_df):
    """
    Find the most specific reliable reference curve for a product.

    Candidate curves must belong to the product's brand and have
    r_squared >= MIN_R_SQUARED. The product's ``cat_ids`` are tried in order
    (deepest category first, as load_products builds them); the first
    brand x category match wins. Otherwise the brand-only curve (NaN cat_id)
    is used.

    Returns:
        (amplitude, decay_rate, baseline, median_first_week, median_preorder,
        median_preorder_days), where median_first_week falls back to 1 when
        stored as 0/None and the last two may be None; or None when no
        candidate curve exists.
    """
    brand = product.get('brand')
    if brand is None or curves_df.empty:
        return None

    candidates = curves_df[
        (curves_df['brand'] == brand) & (curves_df['r_squared'] >= MIN_R_SQUARED)
    ]
    if candidates.empty:
        return None

    def _to_params(curve):
        preorder_sales = curve.get('median_preorder_sales')
        preorder_days = curve.get('median_preorder_days')
        return (
            float(curve['amplitude']),
            float(curve['decay_rate']),
            float(curve['baseline']),
            float(curve['median_first_week_sales'] or 1),
            float(preorder_sales) if pd.notna(preorder_sales) else None,
            float(preorder_days) if pd.notna(preorder_days) else None,
        )

    for cat_id in product.get('cat_ids') or []:
        exact = candidates[candidates['cat_id'] == cat_id]
        if len(exact) > 0:
            return _to_params(exact.iloc[0])

    brand_level = candidates[candidates['cat_id'].isna()]
    return _to_params(brand_level.iloc[0]) if len(brand_level) > 0 else None
def forecast_from_curve(curve_params, scale_factor, age_days, horizon_days):
    """
    Turn a weekly-unit decay curve into a daily forecast.

    Only the decaying envelope is multiplied by ``scale_factor``; the baseline
    is left unscaled so strong sellers do not get an inflated long tail:

        daily = (A/7) * exp(-λ * (age_days + d) / 7) * scale + C/7

    Args:
        curve_params: sequence whose first three items are
            (amplitude, decay_rate, baseline), in weekly units.
        scale_factor: multiplier for the decay envelope.
        age_days: product age in days at the first forecast day.
        horizon_days: number of daily values to produce.

    Returns:
        float ndarray of length max(horizon_days, 0); negative (and NaN)
        values are replaced by 0.0.
    """
    amplitude, decay_rate, baseline = curve_params[:3]
    per_day_amp = amplitude / 7.0
    per_day_base = baseline / 7.0

    elapsed_weeks = (age_days + np.arange(horizon_days)) / 7.0
    raw = per_day_amp * np.exp(-decay_rate * elapsed_weeks) * scale_factor + per_day_base
    # np.where rather than np.maximum so NaN maps to 0.0, like max(0.0, x).
    return np.where(raw > 0.0, raw, 0.0)
# ---------------------------------------------------------------------------
# Batch data loading (eliminates N+1 per-product queries)
# ---------------------------------------------------------------------------
def batch_load_product_data(conn, products):
    """
    Batch-load all per-product data needed for forecasting in a few queries
    instead of one query per product.

    Args:
        conn: open DB connection.
        products: DataFrame with at least 'pid' and 'phase' columns.

    Returns dict with keys:
        preorder_sales: {pid: units} — pre-order units (orders before first received)
        preorder_days: {pid: days} — days the pre-orders accumulated over, from
            the first pre-order to first receipt (or to now if not yet
            received), at least 1
        launch_sales: {pid: units} — units sold in the first 14 days after first received
        decay_velocity: {pid: avg} — mean units_sold over snapshot rows from the last 30 days
        mature_history: {pid: DataFrame} — daily sales history for smoothing
    Products with no matching rows are absent from the inner dicts.
    """
    data = {
        'preorder_sales': {},
        'preorder_days': {},
        'launch_sales': {},
        'decay_velocity': {},
        'mature_history': {},
    }

    # Pre-order sales: orders placed BEFORE first received date, plus the
    # window they accumulated over (for daily-rate normalization). The window
    # ends at first receipt when the product has been received, matching how
    # the brand median pre-order window is measured in build_reference_curves;
    # measuring to NOW would inflate it for recently received products.
    preorder_pids = products[products['phase'] == 'preorder']['pid'].tolist()
    if preorder_pids:
        sql = """
            SELECT o.pid,
                   COALESCE(SUM(o.quantity), 0) AS preorder_units,
                   GREATEST(
                       EXTRACT(DAY FROM COALESCE(pm.date_first_received, NOW()) - MIN(o.date)),
                       1
                   ) AS preorder_days
            FROM orders o
            LEFT JOIN product_metrics pm ON pm.pid = o.pid
            WHERE o.pid = ANY(%s)
              AND o.canceled IS DISTINCT FROM TRUE
              AND (pm.date_first_received IS NULL OR o.date < pm.date_first_received)
            GROUP BY o.pid, pm.date_first_received
        """
        df = execute_query(conn, sql, [preorder_pids])
        for pid, units, days in zip(df['pid'], df['preorder_units'], df['preorder_days']):
            data['preorder_sales'][int(pid)] = float(units)
            data['preorder_days'][int(pid)] = float(days)
        log.info(f"Batch loaded pre-order sales for {len(data['preorder_sales'])}/{len(preorder_pids)} preorder products")

    # Launch sales: first 14 days after first received
    launch_pids = products[products['phase'] == 'launch']['pid'].tolist()
    if launch_pids:
        sql = """
            SELECT dps.pid, COALESCE(SUM(dps.units_sold), 0) AS total_sold
            FROM daily_product_snapshots dps
            JOIN product_metrics pm ON pm.pid = dps.pid
            WHERE dps.pid = ANY(%s)
              AND dps.snapshot_date >= pm.date_first_received::date
              AND dps.snapshot_date < pm.date_first_received::date + INTERVAL '14 days'
            GROUP BY dps.pid
        """
        df = execute_query(conn, sql, [launch_pids])
        for pid, total in zip(df['pid'], df['total_sold']):
            data['launch_sales'][int(pid)] = float(total)
        log.info(f"Batch loaded launch sales for {len(data['launch_sales'])}/{len(launch_pids)} launch products")

    # Decay recent velocity: average daily sales over last 30 days
    decay_pids = products[products['phase'] == 'decay']['pid'].tolist()
    if decay_pids:
        sql = """
            SELECT dps.pid, AVG(COALESCE(dps.units_sold, 0)) AS avg_daily
            FROM daily_product_snapshots dps
            WHERE dps.pid = ANY(%s)
              AND dps.snapshot_date >= CURRENT_DATE - INTERVAL '30 days'
            GROUP BY dps.pid
        """
        df = execute_query(conn, sql, [decay_pids])
        for pid, avg_daily in zip(df['pid'], df['avg_daily']):
            data['decay_velocity'][int(pid)] = float(avg_daily)
        log.info(f"Batch loaded decay velocity for {len(data['decay_velocity'])}/{len(decay_pids)} decay products")

    # Mature daily history: full time series for exponential smoothing
    mature_pids = products[products['phase'] == 'mature']['pid'].tolist()
    if mature_pids:
        sql = """
            SELECT dps.pid, dps.snapshot_date, COALESCE(dps.units_sold, 0) AS units_sold
            FROM daily_product_snapshots dps
            WHERE dps.pid = ANY(%s)
              AND dps.snapshot_date >= CURRENT_DATE - INTERVAL '1 day' * %s
            ORDER BY dps.pid, dps.snapshot_date
        """
        df = execute_query(conn, sql, [mature_pids, EXP_SMOOTHING_WINDOW])
        for pid, group in df.groupby('pid'):
            data['mature_history'][int(pid)] = group.copy()
        log.info(f"Batch loaded history for {len(data['mature_history'])}/{len(mature_pids)} mature products")

    return data
# ---------------------------------------------------------------------------
# Per-product scale factor computation
# ---------------------------------------------------------------------------
def compute_scale_factor(phase, product, curve_info, batch_data):
    """
    Estimate how strongly this product sells relative to its brand curve.

    The factor multiplies only the decay envelope (never the baseline).
    It is derived from:
      * preorder: pre-order (or basket) daily rate vs the brand's median
        pre-order daily rate, or vs its median first-week daily rate;
      * launch: first-week sales projected from days observed so far vs the
        brand median first week;
      * decay: recent velocity minus the daily baseline, divided by the
        curve's envelope at the product's age (only when the envelope is not
        negligible).
    Any other phase, or a missing curve, yields 1.0. Results are clamped to
    [0.1, 5.0] for preorder (noisier signal) and [0.1, 8.0] otherwise.
    """
    if curve_info is None:
        return 1.0

    pid = int(product['pid'])
    amplitude, decay_rate, baseline, median_fw, median_preorder, med_preorder_days = curve_info
    ceiling = 5.0 if phase == 'preorder' else 8.0
    scale = 1.0

    if phase == 'preorder':
        units = batch_data['preorder_sales'].get(pid, 0)
        days = batch_data['preorder_days'].get(pid, 1)
        baskets = product.get('baskets') or 0
        has_orders = units > 0

        # Orders accumulated over too few days are noise: use the brand average.
        if days < MIN_PREORDER_DAYS and has_orders:
            return 1.0

        # Order units are the primary signal; baskets stand in when there are none.
        if has_orders:
            signal, window = units, days
        else:
            signal, window = baskets, max(days, 14)
        per_day = signal / max(window, 1)

        # Compare against the brand's own median pre-order window rather than
        # this product's window, to avoid systematic bias.
        if median_preorder and median_preorder > 0:
            brand_window = max(med_preorder_days or window, 1)
            scale = per_day / (median_preorder / brand_window)
        elif median_fw > 0 and per_day > 0:
            scale = per_day / (median_fw / 7.0)

    elif phase == 'launch':
        sold = batch_data['launch_sales'].get(pid, 0)
        age = max(0, product.get('age_days') or 0)
        if median_fw > 0 and sold > 0:
            observed = min(age, 14)
            if observed > 0:
                scale = (sold / observed) * 7 / median_fw

    elif phase == 'decay':
        velocity = batch_data['decay_velocity'].get(pid, 0)
        age = max(0, product.get('age_days') or 0)
        # value = (A/7)*exp(-λt)*scale + C/7  =>  scale = (actual - C/7) / envelope
        envelope = (amplitude / 7.0) * np.exp(-decay_rate * (age / 7.0))
        # Floor at 10% of the week-1 daily value so a vanishing envelope
        # cannot blow the factor up at high ages.
        floor = max(0.01, amplitude / 70.0)
        if envelope > floor and velocity > 0:
            scale = (velocity - baseline / 7.0) / envelope

    return max(0.1, min(scale, ceiling))
# ---------------------------------------------------------------------------
# Mature product forecast (Holt's double exponential smoothing)
# ---------------------------------------------------------------------------
def forecast_mature(product, history_df):
    """
    Forecast for a mature/evergreen product using Holt's linear trend method
    on recent daily sales history. Holt's adds a trend component over SES,
    so it naturally pulls the forecast back down after a sales spike instead
    of persisting the inflated level.

    Args:
        product: product record with 'pid' and 'sales_velocity_daily'.
        history_df: DataFrame with 'snapshot_date' and 'units_sold' columns
            (one product's rows), or None.

    Returns:
        ndarray of FORECAST_HORIZON_DAYS daily unit forecasts. Model output is
        clipped to [0, max(velocity, observed mean) * MAX_SMOOTHING_MULTIPLIER].

    Falls back to SES, then to a flat velocity line, when a model raises or
    produces non-finite values (which np.clip would otherwise pass through).
    """
    pid = int(product['pid'])
    velocity = product.get('sales_velocity_daily') or 0
    flat = np.full(FORECAST_HORIZON_DAYS, float(velocity))

    if history_df is None or history_df.empty or len(history_df) < 7:
        # Not enough data — flat velocity
        return flat

    # Fill date gaps with 0 sales (days where product had no snapshot = no sales).
    # Only units_sold is resampled; other columns (e.g. pid) must not be summed.
    daily = history_df[['snapshot_date', 'units_sold']].copy()
    daily['snapshot_date'] = pd.to_datetime(daily['snapshot_date'])
    series = (
        daily.set_index('snapshot_date')['units_sold']
        .resample('D').sum()
        .fillna(0)
        .to_numpy(dtype=float)
    )

    # Need at least 2 non-zero values for smoothing
    if np.count_nonzero(series) < 2:
        return flat

    # Cap: prevent runaway forecasts from one-time spikes. Use the higher of
    # 30d velocity or the observed mean so sustained increases are respected.
    observed_mean = float(np.mean(series))
    cap = max(velocity, observed_mean) * MAX_SMOOTHING_MULTIPLIER

    # Holt's with damped trend first: phi dampens the trend over the horizon so
    # forecasts converge to a level instead of extrapolating indefinitely.
    # SES is the fallback (e.g. when Holt's fails on sparse data).
    candidates = (
        ("Holt", lambda: Holt(series, initialization_method='estimated', damped_trend=True)),
        ("SES", lambda: SimpleExpSmoothing(series, initialization_method='estimated')),
    )
    for name, build_model in candidates:
        try:
            fit = build_model().fit(optimized=True)
            forecast = np.asarray(fit.forecast(FORECAST_HORIZON_DAYS), dtype=float)
        except Exception as e:  # model boundary: any statsmodels failure -> next fallback
            log.debug(f"{name} smoothing failed for pid {pid}: {e}")
            continue
        if np.all(np.isfinite(forecast)):
            return np.clip(forecast, 0, cap)
        log.debug(f"{name} smoothing produced non-finite values for pid {pid}")

    return flat
def forecast_dormant():
    """Return an all-zero daily forecast of length FORECAST_HORIZON_DAYS."""
    return np.full(FORECAST_HORIZON_DAYS, 0.0)
# ---------------------------------------------------------------------------
# Accuracy-driven confidence margins
# ---------------------------------------------------------------------------
# Fallback confidence-interval half-widths per lifecycle phase (as ratios of
# the forecast), used until accuracy data provides measured values.
DEFAULT_MARGINS = dict(
    preorder=0.4,
    launch=0.35,
    decay=0.3,
    mature=0.35,
    slow_mover=0.5,
    dormant=0.5,
)

# Bounds applied to accuracy-derived margins.
MIN_MARGIN = 0.15  # never tighter than ±15%
MAX_MARGIN = 1.0   # never wider than ±100%
def load_accuracy_margins(conn):
    """
    Load per-phase confidence margins from recorded forecast accuracy.

    Reads 'by_phase' WMAPE rows from runs with status 'completed' or
    'backfill'. Each phase takes its value from the most recently finished run
    that reported a finite WMAPE for it, clamped to [MIN_MARGIN, MAX_MARGIN].
    WMAPE is already a ratio (e.g. 1.7 = 170%), so it is used directly as the
    margin. Phases with no usable value keep DEFAULT_MARGINS.

    Best effort: any failure is logged as a warning and the defaults returned.

    Returns:
        dict of phase -> margin.
    """
    margins = dict(DEFAULT_MARGINS)
    try:
        # PostgreSQL sorts NULLs first under DESC, so a run without finished_at
        # would otherwise be treated as the newest; push those last.
        df = execute_query(conn, """
            SELECT fa.dimension_value AS phase, fa.wmape
            FROM forecast_accuracy fa
            JOIN forecast_runs fr ON fr.id = fa.run_id
            WHERE fa.metric_type = 'by_phase'
              AND fr.status IN ('completed', 'backfill')
              AND fa.wmape IS NOT NULL
            ORDER BY fr.finished_at DESC NULLS LAST, fr.id DESC
        """)
        if df.empty:
            log.info("No accuracy data available, using default confidence margins")
            return margins

        # NaN would slip through the clamp as MIN_MARGIN (min(nan, MAX) is nan,
        # max(MIN, nan) is MIN), silently giving the tightest interval; drop
        # non-finite values so an older finite value or the default applies.
        wmape = pd.to_numeric(df['wmape'], errors='coerce')
        usable = df.assign(wmape=wmape)[np.isfinite(wmape)]

        # Rows are newest-first, so the first row per phase is the latest value.
        latest = usable.drop_duplicates(subset='phase', keep='first')
        for phase, value in zip(latest['phase'], latest['wmape']):
            margins[phase] = max(MIN_MARGIN, min(float(value), MAX_MARGIN))

        log.info(f"Loaded accuracy-based margins: {', '.join(f'{k}={v:.2f}' for k, v in margins.items())}")
    except Exception as e:
        log.warning(f"Could not load accuracy margins, using defaults: {e}")
    return margins
# ---------------------------------------------------------------------------
# Main orchestration
# ---------------------------------------------------------------------------
# Number of products processed between streaming flushes of forecast rows to the DB.
FLUSH_EVERY_PRODUCTS = 5_000
def generate_all_forecasts(conn, curves_df, dow_indices, monthly_indices=None,
accuracy_margins=None):
"""Classify all products, batch-load data, generate and stream-write forecasts.
Writes forecast rows to product_forecasts in chunks to avoid accumulating
millions of rows in memory (37K products × 90 days = 3.3M rows).
"""
if monthly_indices is None:
monthly_indices = {m: 1.0 for m in range(1, 13)}
if accuracy_margins is None:
accuracy_margins = dict(DEFAULT_MARGINS)
log.info("Loading products for classification...")
products = load_products(conn)
log.info(f"Loaded {len(products)} visible products")
# Classify each product
products['phase'] = products.apply(classify_phase, axis=1)
phase_counts = products['phase'].value_counts().to_dict()
log.info(f"Phase distribution: {phase_counts}")
# Batch-load per-product data (replaces per-product queries)
log.info("Batch loading product data...")
batch_data = batch_load_product_data(conn, products)
today = date.today()
forecast_dates = [today + timedelta(days=i) for i in range(FORECAST_HORIZON_DAYS)]
# Pre-compute DOW and seasonal multipliers for each forecast date
dow_multipliers = [dow_indices.get(d.isoweekday(), 1.0) for d in forecast_dates]
seasonal_multipliers = [monthly_indices.get(d.month, 1.0) for d in forecast_dates]
# TRUNCATE before streaming writes
with conn.cursor() as cur:
cur.execute("TRUNCATE product_forecasts")
conn.commit()
buffer = []
methods = {}
processed = 0
errors = 0
total_rows = 0
insert_sql = """
INSERT INTO product_forecasts
(pid, forecast_date, forecast_units, forecast_revenue,
lifecycle_phase, forecast_method, confidence_lower,
confidence_upper)
VALUES %s
"""
def flush_buffer():
nonlocal buffer, total_rows
if not buffer:
return
with conn.cursor() as cur:
psycopg2.extras.execute_values(
cur, insert_sql, buffer,
template="(%s, %s, %s, %s, %s, %s, %s, %s)",
page_size=BATCH_SIZE,
)
conn.commit()
total_rows += len(buffer)
buffer = []
for _, product in products.iterrows():
pid = int(product['pid'])
phase = product['phase']
price = float(product['current_price'] or 0)
age = max(0, product.get('age_days') or 0)
try:
curve_info = get_curve_for_product(product, curves_df)
if phase in ('preorder', 'launch'):
if curve_info:
scale = compute_scale_factor(phase, product, curve_info, batch_data)
forecasts = forecast_from_curve(curve_info, scale, age, FORECAST_HORIZON_DAYS)
method = 'lifecycle_curve'
else:
# No reliable curve — fall back to velocity if available
velocity = product.get('sales_velocity_daily') or 0
if velocity > 0:
forecasts = np.full(FORECAST_HORIZON_DAYS, velocity)
method = 'velocity'
else:
forecasts = forecast_dormant()
method = 'zero'
elif phase == 'decay':
if curve_info:
scale = compute_scale_factor(phase, product, curve_info, batch_data)
forecasts = forecast_from_curve(curve_info, scale, age, FORECAST_HORIZON_DAYS)
method = 'lifecycle_curve'
else:
velocity = product.get('sales_velocity_daily') or 0
forecasts = np.full(FORECAST_HORIZON_DAYS, velocity)
method = 'velocity'
elif phase == 'mature':
history = batch_data['mature_history'].get(pid)
forecasts = forecast_mature(product, history)
method = 'exp_smoothing'
elif phase == 'slow_mover':
velocity = product.get('sales_velocity_daily') or 0
forecasts = np.full(FORECAST_HORIZON_DAYS, velocity)
method = 'velocity'
else: # dormant
forecasts = forecast_dormant()
method = 'zero'
# Confidence interval: use accuracy-calibrated margins per phase
base_margin = accuracy_margins.get(phase, 0.5)
for i, d in enumerate(forecast_dates):
base_units = float(forecasts[i]) if i < len(forecasts) else 0.0
# Apply day-of-week and seasonal adjustments
units = base_units * dow_multipliers[i] * seasonal_multipliers[i]
# Widen confidence interval as horizon grows: day 0 = base, day 89 ≈ +50% wider
horizon_factor = 1.0 + 0.5 * (i / max(FORECAST_HORIZON_DAYS - 1, 1))
margin = base_margin * horizon_factor
buffer.append((
pid, d,
round(units, 2),
round(units * price, 4),
phase, method,
round(units * max(1 - margin, 0), 2),
round(units * (1 + margin), 2),
))
methods[method] = methods.get(method, 0) + 1
except Exception as e:
log.warning(f"Error forecasting pid {pid}: {e}")
errors += 1
# Write zero forecast so we have complete coverage
for d in forecast_dates:
buffer.append((pid, d, 0, 0, phase, 'zero', 0, 0))
processed += 1
if processed % FLUSH_EVERY_PRODUCTS == 0:
flush_buffer()
log.info(f" Processed {processed}/{len(products)} products ({total_rows} rows written)...")
# Final flush
flush_buffer()
log.info(f"Forecast generation complete. {processed} products, {errors} errors, {total_rows} rows")
log.info(f"Method distribution: {methods}")
return total_rows, processed, phase_counts
def archive_forecasts(conn, run_id):
    """
    Move the previous completed run's already-elapsed forecasts into history.

    Creates product_forecasts_history (and its indexes) if missing. It then finds
    the most recently finished run with status 'completed' and copies every
    product_forecasts row dated before today into history under that run's id
    (duplicates are ignored). Finally it deletes history rows whose
    forecast_date is more than 90 days old. ``run_id`` is accepted for call-site
    symmetry but not used.

    Returns:
        Number of rows archived (0 when there is no previous completed run).
    """
    create_history_sql = """
        CREATE TABLE IF NOT EXISTS product_forecasts_history (
            run_id INT NOT NULL,
            pid BIGINT NOT NULL,
            forecast_date DATE NOT NULL,
            forecast_units NUMERIC(10,2),
            forecast_revenue NUMERIC(14,4),
            lifecycle_phase TEXT,
            forecast_method TEXT,
            confidence_lower NUMERIC(10,2),
            confidence_upper NUMERIC(10,2),
            generated_at TIMESTAMP,
            PRIMARY KEY (run_id, pid, forecast_date)
        )
    """
    index_statements = (
        "CREATE INDEX IF NOT EXISTS idx_pfh_date ON product_forecasts_history(forecast_date)",
        "CREATE INDEX IF NOT EXISTS idx_pfh_pid_date ON product_forecasts_history(pid, forecast_date)",
    )
    copy_past_sql = """
        INSERT INTO product_forecasts_history
            (run_id, pid, forecast_date, forecast_units, forecast_revenue,
             lifecycle_phase, forecast_method, confidence_lower, confidence_upper, generated_at)
        SELECT %s, pid, forecast_date, forecast_units, forecast_revenue,
               lifecycle_phase, forecast_method, confidence_lower, confidence_upper, generated_at
        FROM product_forecasts
        WHERE forecast_date < CURRENT_DATE
        ON CONFLICT (run_id, pid, forecast_date) DO NOTHING
    """

    with conn.cursor() as cur:
        cur.execute(create_history_sql)
        for ddl in index_statements:
            cur.execute(ddl)

        # The rows currently in product_forecasts belong to the last completed run.
        cur.execute("""
            SELECT id FROM forecast_runs
            WHERE status = 'completed'
            ORDER BY finished_at DESC
            LIMIT 1
        """)
        latest = cur.fetchone()
        if latest is None:
            log.info("No previous completed run found, skipping archive")
            conn.commit()
            return 0

        (source_run_id,) = latest

        cur.execute(copy_past_sql, (source_run_id,))
        archived = cur.rowcount
        conn.commit()

        log.info(
            f"Archived {archived} historical forecast rows from run {source_run_id}"
            if archived > 0
            else "No past-date forecasts to archive"
        )

        # Retain only 90 days of history for accuracy analysis.
        cur.execute("DELETE FROM product_forecasts_history WHERE forecast_date < CURRENT_DATE - INTERVAL '90 days'")
        removed = cur.rowcount
        if removed > 0:
            log.info(f"Pruned {removed} old history rows (>90 days)")
        conn.commit()

    return archived
def compute_accuracy(conn, run_id):
    """
    Compute forecast accuracy metrics from archived history vs. actual sales.

    Joins product_forecasts_history with daily_product_snapshots on
    (pid, forecast_date = snapshot_date) to compare forecast vs. actual units.
    For each (pid, forecast_date) only the row from the run with the latest
    started_at is used. A missing snapshot counts as zero actual units, and rows
    where both forecast and actual are zero are ignored.

    Results are upserted into forecast_accuracy under ``run_id``, broken down by:
      - overall: single aggregate row ('all'), excluding dormant products
      - by_phase: per lifecycle phase
      - by_lead_time: bucketed by days between run start and forecast_date
      - by_method: per forecast method
      - daily: per forecast_date (for trend charts)

    Metrics per group:
      - mae: mean absolute error
      - wmape: sum |error| / sum actual (NULL when there are no actual units)
      - bias: mean of (forecast - actual)
      - rmse: root mean squared error

    Afterwards it deletes accuracy rows whose run is not one that finished in
    the last 90 days or is still unfinished. Returns early (no-op) when history
    is empty.
    """
    with conn.cursor() as cur:
        # Ensure accuracy table exists
        cur.execute("""
            CREATE TABLE IF NOT EXISTS forecast_accuracy (
                run_id INT NOT NULL,
                metric_type TEXT NOT NULL,
                dimension_value TEXT NOT NULL,
                sample_size INT,
                total_actual_units NUMERIC(12,2),
                total_forecast_units NUMERIC(12,2),
                mae NUMERIC(10,4),
                wmape NUMERIC(10,4),
                bias NUMERIC(10,4),
                rmse NUMERIC(10,4),
                computed_at TIMESTAMP NOT NULL DEFAULT NOW(),
                PRIMARY KEY (run_id, metric_type, dimension_value)
            )
        """)
        conn.commit()

        # Check if we have any history to analyze
        cur.execute("SELECT COUNT(*) FROM product_forecasts_history")
        history_count = cur.fetchone()[0]
        if history_count == 0:
            log.info("No forecast history available for accuracy computation")
            return

        # For each (pid, forecast_date) pair, keep only the most recent run's
        # forecast row. This prevents double-counting when multiple runs have
        # archived forecasts for the same product×date combination.
        accuracy_cte = """
            WITH ranked_history AS (
                SELECT
                    pfh.*,
                    fr.started_at,
                    ROW_NUMBER() OVER (
                        PARTITION BY pfh.pid, pfh.forecast_date
                        ORDER BY fr.started_at DESC
                    ) AS rn
                FROM product_forecasts_history pfh
                JOIN forecast_runs fr ON fr.id = pfh.run_id
            ),
            accuracy AS (
                SELECT
                    rh.lifecycle_phase,
                    rh.forecast_method,
                    rh.forecast_date,
                    (rh.forecast_date - rh.started_at::date) AS lead_days,
                    rh.forecast_units,
                    COALESCE(dps.units_sold, 0) AS actual_units,
                    (rh.forecast_units - COALESCE(dps.units_sold, 0)) AS error,
                    ABS(rh.forecast_units - COALESCE(dps.units_sold, 0)) AS abs_error
                FROM ranked_history rh
                LEFT JOIN daily_product_snapshots dps
                    ON dps.pid = rh.pid AND dps.snapshot_date = rh.forecast_date
                WHERE rh.rn = 1
                  AND NOT (rh.forecast_units = 0 AND COALESCE(dps.units_sold, 0) = 0)
            )
        """

        # Single definition of the lead-time buckets. Previously it was duplicated
        # in the dimension and filter queries. Any lead outside 0-59 days falls
        # into the final bucket.
        lead_time_bucket = """
            CASE
                WHEN lead_days BETWEEN 0 AND 6 THEN '1-7d'
                WHEN lead_days BETWEEN 7 AND 13 THEN '8-14d'
                WHEN lead_days BETWEEN 14 AND 29 THEN '15-30d'
                WHEN lead_days BETWEEN 30 AND 59 THEN '31-60d'
                ELSE '61-90d'
            END
        """

        # metric_type -> (grouping expression, row filter). NULL phases/methods are
        # excluded explicitly because dimension_value is NOT NULL.
        dimensions = {
            'overall': ("'all'::text", "lifecycle_phase != 'dormant'"),
            'by_phase': ("lifecycle_phase", "lifecycle_phase IS NOT NULL"),
            'by_lead_time': (lead_time_bucket, "TRUE"),
            'by_method': ("forecast_method", "forecast_method IS NOT NULL"),
            'daily': ("forecast_date::text", "TRUE"),
        }

        def _opt_float(value):
            return float(value) if value is not None else None

        metric_rows = []
        for metric_type, (dim_expr, row_filter) in dimensions.items():
            # One grouped aggregate per dimension, rather than CROSS JOINing
            # every distinct dimension value against every accuracy row.
            cur.execute(f"""
                {accuracy_cte}
                SELECT
                    {dim_expr} AS dim,
                    COUNT(*) AS sample_size,
                    COALESCE(SUM(a.actual_units), 0) AS total_actual,
                    COALESCE(SUM(a.forecast_units), 0) AS total_forecast,
                    AVG(a.abs_error) AS mae,
                    CASE WHEN SUM(a.actual_units) > 0
                        THEN SUM(a.abs_error) / SUM(a.actual_units)
                        ELSE NULL END AS wmape,
                    AVG(a.error) AS bias,
                    SQRT(AVG(POWER(a.error, 2))) AS rmse
                FROM accuracy a
                WHERE {row_filter}
                GROUP BY 1
            """)
            for dim_val, sample_size, total_actual, total_forecast, mae, wmape, bias, rmse in cur.fetchall():
                metric_rows.append((
                    run_id, metric_type, dim_val, sample_size,
                    float(total_actual), float(total_forecast),
                    _opt_float(mae), _opt_float(wmape),
                    _opt_float(bias), _opt_float(rmse),
                ))

        # Each (metric_type, dim) appears once, so a batched upsert never hits the
        # same conflict key twice within one statement.
        if metric_rows:
            psycopg2.extras.execute_values(
                cur,
                """
                INSERT INTO forecast_accuracy
                    (run_id, metric_type, dimension_value, sample_size,
                     total_actual_units, total_forecast_units, mae, wmape, bias, rmse)
                VALUES %s
                ON CONFLICT (run_id, metric_type, dimension_value)
                DO UPDATE SET
                    sample_size = EXCLUDED.sample_size,
                    total_actual_units = EXCLUDED.total_actual_units,
                    total_forecast_units = EXCLUDED.total_forecast_units,
                    mae = EXCLUDED.mae, wmape = EXCLUDED.wmape,
                    bias = EXCLUDED.bias, rmse = EXCLUDED.rmse,
                    computed_at = NOW()
                """,
                metric_rows,
                template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
            )
        conn.commit()
        total_inserted = len(metric_rows)

        # Prune old accuracy data (keep 90 days of runs + any in-progress run)
        cur.execute("""
            DELETE FROM forecast_accuracy
            WHERE run_id NOT IN (
                SELECT id FROM forecast_runs
                WHERE finished_at >= NOW() - INTERVAL '90 days'
                   OR finished_at IS NULL
            )
        """)
        pruned = cur.rowcount
        conn.commit()

    log.info(f"Accuracy metrics: {total_inserted} rows computed"
             + (f", {pruned} old rows pruned" if pruned > 0 else ""))
def backfill_accuracy_data(conn, backfill_days=30):
    """
    Generate retroactive forecast data for the past N days to bootstrap
    accuracy metrics.

    Uses the current brand curves with per-product scaling to approximate
    what the model would have predicted for each past day. It applies the same
    day-of-week and monthly seasonal multipliers as the live forecast, then
    stores the results in product_forecasts_history for comparison against
    actual snapshots. A product that fails is logged and skipped; it does not
    abort the backfill.

    This is a model backtest (in-sample), not true out-of-sample accuracy,
    but provides much better initial estimates than unscaled brand curves.

    Args:
        conn: open database connection.
        backfill_days: number of past days (ending yesterday) to simulate.

    Returns:
        id of the forecast_runs row created for this backfill (status 'backfill').
    """
    backfill_start_time = time.time()
    log.info(f"Backfilling {backfill_days} days of accuracy data with per-product scaling...")

    # Calendar adjustments, matching what generate_all_forecasts applies.
    dow_indices = compute_dow_indices(conn)
    monthly_indices = compute_monthly_seasonal_indices(conn) or {}

    # Load brand curves (already fitted)
    curves_df = execute_query(conn, """
        SELECT brand, root_category, cat_id, category_level,
               amplitude, decay_rate, baseline,
               r_squared, median_first_week_sales, median_preorder_sales,
               median_preorder_days
        FROM brand_lifecycle_curves
    """)

    # Load products
    products = load_products(conn)
    products['phase'] = products.apply(classify_phase, axis=1)

    # Skip dormant — they forecast 0 and are filtered from accuracy anyway
    active = products[products['phase'] != 'dormant'].copy()
    log.info(f"Backfilling for {len(active)} non-dormant products "
             f"(skipping {len(products) - len(active)} dormant)")

    # Batch load product data for per-product scaling
    batch_data = batch_load_product_data(conn, active)

    today = date.today()
    backfill_start = today - timedelta(days=backfill_days)

    # Create a synthetic run entry
    with conn.cursor() as cur:
        cur.execute("""
            INSERT INTO forecast_runs
                (started_at, finished_at, status, products_forecast,
                 phase_counts, error_message)
            VALUES (%s, NOW(), 'backfill', %s, %s, %s)
            RETURNING id
        """, (
            backfill_start,
            len(active),
            json.dumps({'backfill_days': backfill_days}),
            f'Model backtest: {backfill_days} days with per-product scaling',
        ))
        backfill_run_id = cur.fetchone()[0]
    conn.commit()
    log.info(f"Created backfill run {backfill_run_id} "
             f"(simulated start: {backfill_start})")

    # Per-date inputs shared by every product:
    # (date, days before today, DOW multiplier, seasonal multiplier).
    backfill_dates = [backfill_start + timedelta(days=i) for i in range(backfill_days)]
    date_inputs = [
        (d, (today - d).days,
         dow_indices.get(d.isoweekday(), 1.0),
         monthly_indices.get(d.month, 1.0))
        for d in backfill_dates
    ]

    # Generate retroactive forecasts
    all_rows = []
    errors = 0
    for _, product in active.iterrows():
        pid = int(product['pid'])
        try:
            price = float(product['current_price'] or 0)
            current_age = product.get('age_days')
            velocity = float(product.get('sales_velocity_daily') or 0)
            phase = product['phase']

            curve_info = get_curve_for_product(product, curves_df)
            # The scale factor is only used by the curve branch below. Compute it
            # only when a curve exists, as the live forecast path does.
            scale = compute_scale_factor(phase, product, curve_info, batch_data) if curve_info else None

            # Stage rows per product so a mid-product failure leaves no partial rows.
            product_rows = []
            for forecast_date, days_ago, dow_mult, season_mult in date_inputs:
                # Product's age on that date
                past_age = (current_age - days_ago) if current_age is not None else None
                if past_age is not None and past_age < 0:
                    # Product didn't exist yet on this date
                    continue

                # Determine what phase the product was likely in
                if past_age is not None and past_age <= LAUNCH_AGE_DAYS:
                    past_phase = 'launch'
                elif past_age is not None and past_age <= DECAY_AGE_DAYS:
                    past_phase = 'decay'
                else:
                    past_phase = phase  # use current classification

                # Compute forecast value for this date
                if past_phase in ('launch', 'decay', 'preorder') and curve_info:
                    amplitude, decay_rate, baseline = curve_info[:3]
                    t_weeks = max(0, past_age or 0) / 7.0
                    # Scale only the decay envelope, not the baseline
                    daily_value = (amplitude / 7.0) * np.exp(-decay_rate * t_weeks) * scale + (baseline / 7.0)
                    units = max(0.0, float(daily_value))
                    method = 'lifecycle_curve'
                elif past_phase == 'mature' and velocity > 0:
                    # Labelled like the live method, but the backtest approximates
                    # exponential smoothing with flat velocity.
                    units = velocity
                    method = 'exp_smoothing'
                else:
                    units = velocity if velocity > 0 else 0.0
                    method = 'velocity' if velocity > 0 else 'zero'

                # Apply day-of-week and seasonal multipliers (same order as live path)
                units = units * dow_mult * season_mult

                if units == 0 and method == 'zero':
                    continue  # skip zero-zero rows

                margin = 0.3 if method == 'lifecycle_curve' else 0.4
                product_rows.append((
                    backfill_run_id, pid, forecast_date,
                    round(float(units), 2),
                    round(float(units * price), 4),
                    past_phase, method,
                    round(float(units * (1 - margin)), 2),
                    round(float(units * (1 + margin)), 2),
                    backfill_start,  # generated_at
                ))

            all_rows.extend(product_rows)
        except Exception as e:
            log.warning(f"Backfill error for pid {pid}: {e}")
            errors += 1

    log.info(f"Generated {len(all_rows)} backfill forecast rows")
    if errors:
        log.warning(f"Skipped {errors} products due to errors during backfill")

    # Write to history table
    if all_rows:
        with conn.cursor() as cur:
            sql = """
                INSERT INTO product_forecasts_history
                    (run_id, pid, forecast_date, forecast_units, forecast_revenue,
                     lifecycle_phase, forecast_method, confidence_lower,
                     confidence_upper, generated_at)
                VALUES %s
                ON CONFLICT (run_id, pid, forecast_date) DO NOTHING
            """
            psycopg2.extras.execute_values(
                cur, sql, all_rows,
                template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
                page_size=BATCH_SIZE,
            )
        conn.commit()
        log.info(f"Wrote {len(all_rows)} rows to product_forecasts_history")

    # Now compute accuracy on the backfilled data
    compute_accuracy(conn, backfill_run_id)

    # Mark the backfill run as completed
    backfill_duration = time.time() - backfill_start_time
    with conn.cursor() as cur:
        cur.execute("""
            UPDATE forecast_runs
            SET finished_at = NOW(), status = 'backfill',
                duration_seconds = %s
            WHERE id = %s
        """, (round(backfill_duration, 2), backfill_run_id))
    conn.commit()

    log.info(f"Backfill complete in {backfill_duration:.1f}s")
    return backfill_run_id
def _parse_backfill_days(argv, default=30):
    """Return the day count given after ``--backfill`` in argv, or None if the flag is absent.

    A trailing ``--backfill`` with no value uses ``default``. A non-integer or
    non-positive value exits with a usage message instead of a traceback.
    """
    if '--backfill' not in argv:
        return None
    idx = argv.index('--backfill')
    if idx + 1 >= len(argv):
        return default
    raw = argv[idx + 1]
    try:
        days = int(raw)
    except ValueError:
        raise SystemExit(f"--backfill expects a positive integer number of days, got {raw!r}") from None
    if days <= 0:
        raise SystemExit(f"--backfill expects a positive integer number of days, got {raw!r}")
    return days


def _record_run_failure(conn, run_id, error, duration):
    """Best-effort: mark forecast run ``run_id`` as failed. Logs and never raises."""
    try:
        # A database error leaves the transaction aborted. Without a rollback the
        # UPDATE below would itself fail and the run would stay 'running'.
        conn.rollback()
        with conn.cursor() as cur:
            cur.execute("""
                UPDATE forecast_runs
                SET finished_at = NOW(), status = 'failed',
                    error_message = %s, duration_seconds = %s
                WHERE id = %s
            """, (str(error), round(duration, 2), run_id))
        conn.commit()
    except Exception as update_err:
        log.warning(f"Could not record failure for run {run_id}: {update_err}")


def main():
    """
    Command-line entry point.

    ``--backfill [DAYS]`` (default 30) runs backfill_accuracy_data only.
    Otherwise it runs the full pipeline and records it in forecast_runs:
    seasonal indices, reference curves, archive, accuracy, margins, then
    forecast generation. On failure the run is marked 'failed' and the process
    exits with status 1. The connection is always closed.
    """
    start_time = time.time()
    # Validate CLI arguments before opening a database connection.
    backfill_days = _parse_backfill_days(sys.argv)

    conn = get_connection()
    try:
        # Clean up any stale "running" entries from prior crashes
        cleanup_stale_runs(conn)

        if backfill_days is not None:
            log.info("=" * 60)
            log.info(f"Backfill mode: {backfill_days} days")
            log.info("=" * 60)
            backfill_accuracy_data(conn, backfill_days)
            return

        log.info("=" * 60)
        log.info("Forecast Engine starting")
        log.info("=" * 60)

        run_id = None
        try:
            # Record run start
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO forecast_runs (started_at, status) VALUES (NOW(), 'running') RETURNING id"
                )
                run_id = cur.fetchone()[0]
            conn.commit()

            # Phase 0: Compute day-of-week and monthly seasonal indices
            dow_indices = compute_dow_indices(conn)
            monthly_indices = compute_monthly_seasonal_indices(conn)

            # Phase 1: Build reference curves
            curves_df = build_reference_curves(conn)

            # Phase 2: Archive historical forecasts (before TRUNCATE in generation)
            archive_forecasts(conn, run_id)

            # Phase 3: Compute accuracy from archived history vs actuals
            compute_accuracy(conn, run_id)

            # Phase 3b: Load accuracy-calibrated confidence margins
            accuracy_margins = load_accuracy_margins(conn)

            # Phase 4: Generate and stream-write forecasts (TRUNCATE + chunked INSERT)
            total_rows, products_forecast, phase_counts = generate_all_forecasts(
                conn, curves_df, dow_indices, monthly_indices, accuracy_margins
            )

            duration = time.time() - start_time
            curve_count = len(curves_df) if not curves_df.empty else 0

            # Record run completion (include DOW and seasonal indices in metadata)
            with conn.cursor() as cur:
                cur.execute("""
                    UPDATE forecast_runs
                    SET finished_at = NOW(), status = 'completed',
                        products_forecast = %s, phase_counts = %s,
                        curve_count = %s, duration_seconds = %s
                    WHERE id = %s
                """, (
                    products_forecast,
                    json.dumps({
                        **phase_counts,
                        '_dow_indices': {str(k): v for k, v in dow_indices.items()},
                        '_seasonal_indices': {str(k): v for k, v in monthly_indices.items()},
                    }),
                    curve_count,
                    round(duration, 2),
                    run_id,
                ))
            conn.commit()

            log.info("=" * 60)
            log.info(f"Forecast complete in {duration:.1f}s")
            log.info(f"  Products: {products_forecast}")
            log.info(f"  Curves: {curve_count}")
            log.info(f"  Phases: {phase_counts}")
            log.info(f"  Rows written: {total_rows}")
            log.info("=" * 60)

        except Exception as e:
            duration = time.time() - start_time
            log.error(f"Forecast engine failed: {e}", exc_info=True)
            if run_id is not None:
                _record_run_failure(conn, run_id, e, duration)
            sys.exit(1)
    finally:
        conn.close()


if __name__ == '__main__':
    main()