"""``/sessions`` API blueprint: shooting-session CRUD, photos and CSV uploads."""

import io
import math
from datetime import date
from pathlib import Path

from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from sqlalchemy import func, select

from extensions import db
from models import SessionPhoto, ShootingSession

from .utils import (
    created,
    err,
    no_content,
    ok,
    current_api_user,
    serialize_analysis,
    serialize_session,
    serialize_session_photo,
)

sessions_bp = Blueprint("api_sessions", __name__, url_prefix="/sessions")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _int_or_none(v):
    """Coerce *v* to a positive ``int``.

    Returns None when *v* is missing, not convertible, non-finite, or <= 0.
    """
    try:
        result = int(v)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) -- the stdlib JSON parser accepts
        # the literal ``Infinity``, so this is reachable from a request body.
        return None
    return result if result > 0 else None


def _float_or_none(v):
    """Coerce *v* to a finite ``float``.

    Returns None when *v* is None, blank, not convertible, too large for a
    float, or NaN/±infinity (which the stdlib JSON parser accepts).
    """
    if v is None or not str(v).strip():
        return None
    try:
        result = float(v)
    except (TypeError, ValueError, OverflowError):
        return None
    return result if math.isfinite(result) else None


def _str_or_none(v):
    """Return *v* stripped of surrounding whitespace, or None.

    None is returned for blank strings and for any non-string value, so a
    malformed JSON payload cannot raise ``AttributeError`` on ``.strip()``.
    """
    if not isinstance(v, str):
        return None
    return v.strip() or None


def _parse_iso_date(v):
    """Parse an ISO date string (``YYYY-MM-DD``); None if *v* is not one."""
    if not isinstance(v, str):
        return None
    try:
        return date.fromisoformat(v.strip())
    except ValueError:
        return None


# Optional session attributes accepted by create/update, grouped by the
# coercion applied to the incoming JSON value.
_TEXT_FIELDS = ("location_name", "weather_cond", "ammo_brand", "ammo_lot", "notes")
_FLOAT_FIELDS = (
    "location_lat",
    "location_lon",
    "weather_temp_c",
    "weather_wind_kph",
    "ammo_weight_gr",
)
_INT_FIELDS = ("distance_m", "rifle_id", "scope_id")
_FIELD_COERCERS = (
    (_TEXT_FIELDS, _str_or_none),
    (_FLOAT_FIELDS, _float_or_none),
    (_INT_FIELDS, _int_or_none),
)


def _apply_optional_fields(s, body, *, partial):
    """Copy the optional attributes from *body* onto session *s*.

    With ``partial=False`` (create) every attribute is assigned, and absent
    keys become None. With ``partial=True`` (PATCH) only keys present in
    *body* are touched; an explicit ``null`` or blank value clears it.
    """
    for names, coerce in _FIELD_COERCERS:
        for name in names:
            if partial and name not in body:
                continue
            setattr(s, name, coerce(body.get(name)))


def _json_object_body():
    """Return the request's JSON body as a dict, or None if it is not an object.

    A missing, empty, or unparsable body yields ``{}``, because
    ``get_json(silent=True)`` returns None in that case.
    """
    body = request.get_json(silent=True) or {}
    return body if isinstance(body, dict) else None


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------


@sessions_bp.get("/")
@jwt_required()
def list_sessions():
    """List the caller's sessions, newest first, paginated by page/per_page.

    ``page`` is clamped to >= 1 and ``per_page`` to 1..100. Each parameter
    falls back to its own default (1 / 20) when it is not an integer.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    # Werkzeug's MultiDict.get returns the default when type conversion fails.
    page = max(1, request.args.get("page", 1, type=int))
    per_page = min(100, max(1, request.args.get("per_page", 20, type=int)))

    total = db.session.scalar(
        select(func.count())
        .select_from(ShootingSession)
        .where(ShootingSession.user_id == user.id)
    ) or 0

    sessions = db.session.scalars(
        select(ShootingSession)
        .where(ShootingSession.user_id == user.id)
        .order_by(ShootingSession.session_date.desc(), ShootingSession.created_at.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
    ).all()

    return ok({
        "data": [serialize_session(s) for s in sessions],
        "total": total,
        "page": page,
        "per_page": per_page,
    })


@sessions_bp.post("/")
@jwt_required()
def create_session():
    """Create a session owned by the caller.

    ``session_date`` (ISO ``YYYY-MM-DD``) is required. ``is_public`` defaults
    to False. All other fields are optional and coerced by the
    ``_*_or_none`` helpers. Responds via ``created`` with the serialized
    session.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    body = _json_object_body()
    if body is None:
        return err("Request body must be a JSON object.", 400)

    raw_date = body.get("session_date")
    if not raw_date or (isinstance(raw_date, str) and not raw_date.strip()):
        return err("session_date is required.", 400)
    session_date = _parse_iso_date(raw_date)
    if session_date is None:
        return err("session_date must be a valid ISO date string (YYYY-MM-DD).", 400)

    s = ShootingSession(user_id=user.id, session_date=session_date)
    s.is_public = bool(body.get("is_public", False))
    _apply_optional_fields(s, body, partial=False)

    db.session.add(s)
    db.session.commit()
    return created(serialize_session(s))


@sessions_bp.get("/<int:session_id>")
@jwt_required(optional=True)
def get_session(session_id: int):
    """Return one session; non-public sessions are visible only to their owner."""
    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)

    user = current_api_user()
    is_owner = bool(user) and s.user_id == user.id
    if not s.is_public and not is_owner:
        return err("Access denied.", 403)

    return ok(serialize_session(s, include_user=True))


@sessions_bp.patch("/<int:session_id>")
@jwt_required()
def update_session(session_id: int):
    """Partially update a session owned by the caller.

    Only keys present in the JSON body are changed. Setting ``is_public``
    also propagates the new value to every analysis attached to the session.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)
    if s.user_id != user.id:
        return err("Access denied.", 403)

    body = _json_object_body()
    if body is None:
        return err("Request body must be a JSON object.", 400)

    # Validate the date first so an invalid request mutates nothing.
    if "session_date" in body:
        session_date = _parse_iso_date(body["session_date"])
        if session_date is None:
            return err("session_date must be a valid ISO date string (YYYY-MM-DD).", 400)
        s.session_date = session_date

    if "is_public" in body:
        s.is_public = bool(body["is_public"])
        # Keep attached analyses' visibility in sync with the session.
        for analysis in s.analyses:
            analysis.is_public = s.is_public

    _apply_optional_fields(s, body, partial=True)

    db.session.commit()
    return ok(serialize_session(s))


@sessions_bp.delete("/<int:session_id>")
@jwt_required()
def delete_session(session_id: int):
    """Delete a session owned by the caller, then remove its photo files."""
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)
    if s.user_id != user.id:
        return err("Access denied.", 403)

    # Capture the paths before the ORM delete. Files are removed only after
    # the commit succeeds, so a failed commit does not leave rows whose
    # files are already gone.
    photo_paths = [photo.photo_path for photo in s.photos]

    db.session.delete(s)
    db.session.commit()

    storage_root = Path(current_app.config["STORAGE_ROOT"])
    for photo_path in photo_paths:
        try:
            (storage_root / photo_path).unlink(missing_ok=True)
        except Exception:
            # Best-effort cleanup: the DB rows are already gone, so log and continue.
            current_app.logger.warning(
                "Could not remove photo file %r of deleted session %s",
                photo_path,
                session_id,
                exc_info=True,
            )

    return no_content()


# ---------------------------------------------------------------------------
# Photos
# ---------------------------------------------------------------------------


@sessions_bp.post("/<int:session_id>/photos")
@jwt_required()
def upload_photo(session_id: int):
    """Attach an uploaded image (multipart field ``photo``) to an owned session.

    An optional ``caption`` form field is stored with it. ``ValueError``
    raised by ``save_session_photo`` is reported as a 422.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)
    if s.user_id != user.id:
        return err("Access denied.", 403)

    photo_file = request.files.get("photo")
    if not photo_file or not photo_file.filename:
        return err("No photo file provided.", 400)

    from storage import save_session_photo

    try:
        photo_path = save_session_photo(user.id, session_id, photo_file)
    except ValueError as e:
        return err(str(e), 422)

    caption = (request.form.get("caption") or "").strip() or None
    photo = SessionPhoto(session_id=session_id, photo_path=photo_path, caption=caption)
    db.session.add(photo)
    try:
        db.session.commit()
    except Exception:
        # Don't orphan the just-saved file when the row cannot be stored.
        # NOTE(review): assumes photo_path is relative to STORAGE_ROOT, the
        # same assumption delete_photo makes when removing files.
        db.session.rollback()
        try:
            (Path(current_app.config["STORAGE_ROOT"]) / photo_path).unlink(missing_ok=True)
        except Exception:
            current_app.logger.warning(
                "Could not remove orphaned photo file %r", photo_path, exc_info=True
            )
        raise

    return created(serialize_session_photo(photo))


@sessions_bp.delete("/<int:session_id>/photos/<int:photo_id>")
@jwt_required()
def delete_photo(session_id: int, photo_id: int):
    """Delete one photo of an owned session, then remove its file from storage."""
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)
    if s.user_id != user.id:
        return err("Access denied.", 403)

    photo = db.session.get(SessionPhoto, photo_id)
    # A photo belonging to a different session is reported as not found.
    if not photo or photo.session_id != session_id:
        return err("Photo not found.", 404)

    photo_path = photo.photo_path
    db.session.delete(photo)
    db.session.commit()

    # Remove the file only after the commit succeeded. Failures are logged,
    # not raised, because the row is already gone.
    try:
        (Path(current_app.config["STORAGE_ROOT"]) / photo_path).unlink(missing_ok=True)
    except Exception:
        current_app.logger.warning(
            "Could not remove photo file %r of deleted photo %s",
            photo_path,
            photo_id,
            exc_info=True,
        )

    return no_content()


# ---------------------------------------------------------------------------
# CSV upload
# ---------------------------------------------------------------------------


@sessions_bp.post("/<int:session_id>/csv")
@jwt_required()
def upload_csv(session_id: int):
    """Run the analyzer pipeline on an uploaded CSV and store it as an analysis.

    The file comes from the multipart field ``csv_file``. ``ValueError`` raised
    anywhere in parsing, grouping, stats, charts or PDF generation is
    reported as a 422. The stored analysis inherits the session's
    ``is_public`` flag.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)
    if s.user_id != user.id:
        return err("Access denied.", 403)

    csv_file = request.files.get("csv_file")
    if not csv_file or not csv_file.filename:
        return err("No csv_file provided.", 400)

    from analyzer.parser import parse_csv
    from analyzer.grouper import detect_groups
    from analyzer.stats import compute_overall_stats, compute_group_stats
    from analyzer.charts import render_group_charts, render_overview_chart
    from analyzer.pdf_report import generate_pdf
    from storage import save_analysis

    # Keep the raw bytes: they are both parsed and persisted by save_analysis.
    csv_bytes = csv_file.read()
    try:
        df = parse_csv(io.BytesIO(csv_bytes))
        groups = detect_groups(df)
        overall = compute_overall_stats(df)
        group_stats = compute_group_stats(groups)
        charts = render_group_charts(
            groups, y_min=overall["min_speed"], y_max=overall["max_speed"]
        )
        overview_chart = render_overview_chart(group_stats)
        pdf_bytes = generate_pdf(overall, group_stats, charts, overview_chart)
    except ValueError as e:
        return err(str(e), 422)

    analysis_id = save_analysis(
        user=user,
        csv_bytes=csv_bytes,
        pdf_bytes=pdf_bytes,
        overall=overall,
        group_stats=group_stats,
        filename=csv_file.filename or "upload.csv",
        session_id=session_id,
        is_public=s.is_public,
    )

    from models import Analysis

    analysis = db.session.get(Analysis, analysis_id)
    return created(serialize_analysis(analysis))