Files
ShooterHub/blueprints/api/sessions.py

328 lines
10 KiB
Python

import io
from datetime import date
from pathlib import Path
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from sqlalchemy import func, select
from extensions import db
from models import SessionPhoto, ShootingSession
from .utils import (
created, err, no_content, ok,
current_api_user, serialize_analysis, serialize_session, serialize_session_photo,
)
# Blueprint holding the JSON API endpoints for shooting sessions; every
# route below is mounted under the "/sessions" URL prefix.
sessions_bp = Blueprint(
    "api_sessions",
    __name__,
    url_prefix="/sessions",
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _int_or_none(v):
try:
result = int(v)
return result if result > 0 else None
except (TypeError, ValueError):
return None
def _float_or_none(v):
try:
return float(v) if v is not None and str(v).strip() else None
except (TypeError, ValueError):
return None
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@sessions_bp.get("/")
@jwt_required()
def list_sessions():
    """List the current user's sessions, newest first, paginated.

    Query parameters:
        page: 1-based page number (default 1, clamped to >= 1).
        per_page: page size (default 20, clamped to 1..100).

    Each parameter is parsed independently, so an unparseable value falls
    back to that parameter's default without resetting the other one.

    Returns:
        200 with ``{"data": [...], "total", "page", "per_page"}``;
        404 if the JWT identity does not resolve to a user.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    # MultiDict.get(..., type=int) returns the default when conversion fails.
    page = max(1, request.args.get("page", 1, type=int))
    per_page = min(100, max(1, request.args.get("per_page", 20, type=int)))

    owned_by_user = ShootingSession.user_id == user.id
    total = db.session.scalar(
        select(func.count()).select_from(ShootingSession).where(owned_by_user)
    ) or 0
    sessions = db.session.scalars(
        select(ShootingSession)
        .where(owned_by_user)
        .order_by(ShootingSession.session_date.desc(), ShootingSession.created_at.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
    ).all()
    return ok({
        "data": [serialize_session(s) for s in sessions],
        "total": total,
        "page": page,
        "per_page": per_page,
    })
@sessions_bp.post("/")
@jwt_required()
def create_session():
    """Create a shooting session owned by the current user.

    Expects a JSON object body. ``session_date`` (ISO ``YYYY-MM-DD``) is
    required; every other field is optional. Text fields are stripped and
    stored as ``None`` when empty. Numeric fields are coerced with
    ``_float_or_none`` / ``_int_or_none``, so invalid values become ``None``.

    Returns:
        201 with the serialized session;
        400 if the body is not a JSON object, ``session_date`` is missing or
        invalid, or a text field holds a non-string value;
        404 if the JWT identity does not resolve to a user.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)

    body = request.get_json(silent=True) or {}
    if not isinstance(body, dict):
        # A JSON array or scalar would otherwise crash on body.get().
        return err("Request body must be a JSON object.", 400)

    text_fields = (
        "session_date", "location_name", "weather_cond", "ammo_brand", "ammo_lot", "notes",
    )
    for field in text_fields:
        value = body.get(field)
        # Falsy values are treated as "not provided" below. Truthy
        # non-strings would crash on .strip() (AttributeError -> 500).
        if value and not isinstance(value, str):
            return err(f"{field} must be a string.", 400)

    def text(field):
        """Return body[field] stripped, or None when missing or blank."""
        return (body.get(field) or "").strip() or None

    date_str = text("session_date")
    if not date_str:
        return err("session_date is required.", 400)
    try:
        session_date = date.fromisoformat(date_str)
    except ValueError:
        return err("session_date must be a valid ISO date string (YYYY-MM-DD).", 400)

    s = ShootingSession(user_id=user.id, session_date=session_date)
    s.is_public = bool(body.get("is_public", False))
    s.location_name = text("location_name")
    s.location_lat = _float_or_none(body.get("location_lat"))
    s.location_lon = _float_or_none(body.get("location_lon"))
    s.distance_m = _int_or_none(body.get("distance_m"))
    s.weather_cond = text("weather_cond")
    s.weather_temp_c = _float_or_none(body.get("weather_temp_c"))
    s.weather_wind_kph = _float_or_none(body.get("weather_wind_kph"))
    # NOTE(review): rifle_id/scope_id are stored without checking that the
    # referenced rifle/scope belongs to this user. Confirm ownership is
    # enforced elsewhere (e.g. DB constraint or serializer).
    s.rifle_id = _int_or_none(body.get("rifle_id"))
    s.scope_id = _int_or_none(body.get("scope_id"))
    s.ammo_brand = text("ammo_brand")
    s.ammo_weight_gr = _float_or_none(body.get("ammo_weight_gr"))
    s.ammo_lot = text("ammo_lot")
    s.notes = text("notes")
    db.session.add(s)
    db.session.commit()
    return created(serialize_session(s))
@sessions_bp.get("/<int:session_id>")
@jwt_required(optional=True)
def get_session(session_id: int):
    """Return a single session, including its owning user.

    A JWT is optional. Public sessions are visible to anyone, while private
    sessions are returned only to their owner.

    Returns:
        200 with the serialized session;
        404 if no session has this id;
        403 if the session is private and the caller is not its owner.
    """
    session_obj = db.session.get(ShootingSession, session_id)
    if not session_obj:
        return err("Session not found.", 404)

    viewer = current_api_user()
    viewer_owns_it = bool(viewer) and session_obj.user_id == viewer.id
    if not (session_obj.is_public or viewer_owns_it):
        return err("Access denied.", 403)
    return ok(serialize_session(session_obj, include_user=True))
@sessions_bp.patch("/<int:session_id>")
@jwt_required()
def update_session(session_id: int):
    """Partially update a session owned by the current user.

    Only keys present in the JSON object body are touched. Text fields are
    stripped, with empty values becoming ``None``. Numeric fields go through
    ``_float_or_none`` / ``_int_or_none``, so invalid values become ``None``.
    Setting ``is_public`` also copies the new visibility onto every analysis
    attached to the session.

    Returns:
        200 with the serialized session;
        400 if the body is not a JSON object, ``session_date`` is not a valid
        ISO date, or a text field holds a non-string value;
        403 if the session belongs to another user;
        404 if the user or the session does not exist.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)
    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)
    if s.user_id != user.id:
        return err("Access denied.", 403)

    body = request.get_json(silent=True) or {}
    if not isinstance(body, dict):
        # A JSON array/string would make `"key" in body` do list/substring
        # membership tests and then crash on body["key"].
        return err("Request body must be a JSON object.", 400)

    # Body keys map 1:1 onto ShootingSession attribute names.
    text_fields = ("location_name", "weather_cond", "ammo_brand", "ammo_lot", "notes")
    float_fields = (
        "location_lat", "location_lon", "weather_temp_c", "weather_wind_kph", "ammo_weight_gr",
    )
    # NOTE(review): rifle_id/scope_id are not checked against the user's own
    # rifles/scopes here. Confirm ownership is enforced elsewhere.
    int_fields = ("distance_m", "rifle_id", "scope_id")

    # Validate everything before mutating `s`, so an error response never
    # leaves a half-applied update on the ORM object.
    new_date = None
    if "session_date" in body:
        try:
            new_date = date.fromisoformat(body["session_date"])
        except (ValueError, TypeError):
            return err("session_date must be a valid ISO date string (YYYY-MM-DD).", 400)
    for field in text_fields:
        value = body.get(field)
        # Falsy values clear the field. Truthy non-strings would crash on
        # .strip() (AttributeError -> 500), so reject them instead.
        if value and not isinstance(value, str):
            return err(f"{field} must be a string.", 400)

    if new_date is not None:
        s.session_date = new_date
    if "is_public" in body:
        s.is_public = bool(body["is_public"])
        # Keep attached analyses' visibility in step with the session.
        for analysis in s.analyses:
            analysis.is_public = s.is_public
    for field in text_fields:
        if field in body:
            setattr(s, field, (body[field] or "").strip() or None)
    for field in float_fields:
        if field in body:
            setattr(s, field, _float_or_none(body[field]))
    for field in int_fields:
        if field in body:
            setattr(s, field, _int_or_none(body[field]))

    db.session.commit()
    return ok(serialize_session(s))
@sessions_bp.delete("/<int:session_id>")
@jwt_required()
def delete_session(session_id: int):
    """Delete a session owned by the current user, plus its photo files.

    The session row is deleted and committed first. The photo files under
    ``STORAGE_ROOT`` are then removed on a best-effort basis: failures are
    logged, not raised.

    Returns:
        204 on success;
        403 if the session belongs to another user;
        404 if the user or the session does not exist.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)
    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)
    if s.user_id != user.id:
        return err("Access denied.", 403)

    storage_root = Path(current_app.config["STORAGE_ROOT"])
    # Collect file paths while the photo rows are still loaded. After the
    # delete is committed those ORM objects should not be relied on.
    photo_files = [storage_root / photo.photo_path for photo in s.photos if photo.photo_path]

    db.session.delete(s)
    db.session.commit()

    # Touch the filesystem only once the delete is durable, so a failed
    # commit cannot leave rows pointing at already-removed files.
    for photo_file in photo_files:
        try:
            photo_file.unlink(missing_ok=True)
        except OSError:
            current_app.logger.warning(
                "Could not delete session photo file %s", photo_file, exc_info=True
            )
    return no_content()
# ---------------------------------------------------------------------------
# Photos
# ---------------------------------------------------------------------------
@sessions_bp.post("/<int:session_id>/photos")
@jwt_required()
def upload_photo(session_id: int):
    """Attach an uploaded image to a session owned by the current user.

    Expects multipart form data with a ``photo`` file part and an optional
    ``caption`` text field. The file is persisted by
    ``storage.save_session_photo``. A ``ValueError`` raised there (presumably
    a rejected file; see that helper) is reported as 422 with its message.

    Returns:
        201 with the serialized photo;
        400 if no file was uploaded;
        403 if the session belongs to another user;
        404 if the user or the session does not exist;
        422 if the storage layer rejects the file.
    """
    owner = current_api_user()
    if not owner:
        return err("User not found.", 404)
    session_obj = db.session.get(ShootingSession, session_id)
    if not session_obj:
        return err("Session not found.", 404)
    if session_obj.user_id != owner.id:
        return err("Access denied.", 403)

    upload = request.files.get("photo")
    if upload is None or not upload.filename:
        return err("No photo file provided.", 400)

    from storage import save_session_photo

    try:
        stored_path = save_session_photo(owner.id, session_id, upload)
    except ValueError as exc:
        return err(str(exc), 422)

    new_photo = SessionPhoto(
        session_id=session_id,
        photo_path=stored_path,
        caption=(request.form.get("caption") or "").strip() or None,
    )
    db.session.add(new_photo)
    db.session.commit()
    return created(serialize_session_photo(new_photo))
@sessions_bp.delete("/<int:session_id>/photos/<int:photo_id>")
@jwt_required()
def delete_photo(session_id: int, photo_id: int):
    """Delete one photo from a session owned by the current user.

    The DB row is deleted and committed first. The file under
    ``STORAGE_ROOT`` is then removed on a best-effort basis: failures are
    logged, not raised.

    Returns:
        204 on success;
        403 if the session belongs to another user;
        404 if the user, the session, or the photo does not exist, or the
        photo belongs to a different session.
    """
    user = current_api_user()
    if not user:
        return err("User not found.", 404)
    s = db.session.get(ShootingSession, session_id)
    if not s:
        return err("Session not found.", 404)
    if s.user_id != user.id:
        return err("Access denied.", 403)
    photo = db.session.get(SessionPhoto, photo_id)
    if not photo or photo.session_id != session_id:
        return err("Photo not found.", 404)

    storage_root = Path(current_app.config["STORAGE_ROOT"])
    # Resolve the file path while the row is still loaded.
    photo_file = storage_root / photo.photo_path if photo.photo_path else None

    db.session.delete(photo)
    db.session.commit()

    # Touch the filesystem only once the delete is durable, so a failed
    # commit cannot leave a row pointing at an already-removed file.
    if photo_file is not None:
        try:
            photo_file.unlink(missing_ok=True)
        except OSError:
            current_app.logger.warning(
                "Could not delete session photo file %s", photo_file, exc_info=True
            )
    return no_content()
# ---------------------------------------------------------------------------
# CSV upload
# ---------------------------------------------------------------------------
@sessions_bp.post("/<int:session_id>/csv")
@jwt_required()
def upload_csv(session_id: int):
    """Analyse an uploaded CSV and store the result against a session.

    Expects multipart form data with a ``csv_file`` part. The CSV is run
    through the ``analyzer`` pipeline: parse, group detection, stats, charts
    and PDF. The result is then persisted by ``storage.save_analysis`` with
    the session's current visibility. Any ``ValueError`` from the pipeline
    is reported as 422 with its message.

    Returns:
        201 with the serialized analysis;
        400 if no file was uploaded;
        403 if the session belongs to another user;
        404 if the user or the session does not exist;
        422 if the CSV cannot be analysed.
    """
    owner = current_api_user()
    if not owner:
        return err("User not found.", 404)
    session_obj = db.session.get(ShootingSession, session_id)
    if not session_obj:
        return err("Session not found.", 404)
    if session_obj.user_id != owner.id:
        return err("Access denied.", 403)

    upload = request.files.get("csv_file")
    if not upload or not upload.filename:
        return err("No csv_file provided.", 400)

    # Heavy analysis dependencies are imported lazily, only when needed.
    from analyzer.parser import parse_csv
    from analyzer.grouper import detect_groups
    from analyzer.stats import compute_overall_stats, compute_group_stats
    from analyzer.charts import render_group_charts, render_overview_chart
    from analyzer.pdf_report import generate_pdf
    from storage import save_analysis

    try:
        raw_csv = upload.read()
        frame = parse_csv(io.BytesIO(raw_csv))
        groups = detect_groups(frame)
        overall = compute_overall_stats(frame)
        group_stats = compute_group_stats(groups)
        # All group charts share one y-axis range, taken from the overall stats.
        group_charts = render_group_charts(
            groups,
            y_min=overall["min_speed"],
            y_max=overall["max_speed"],
        )
        overview_chart = render_overview_chart(group_stats)
        pdf_bytes = generate_pdf(overall, group_stats, group_charts, overview_chart)
    except ValueError as exc:
        return err(str(exc), 422)

    new_analysis_id = save_analysis(
        user=owner,
        csv_bytes=raw_csv,
        pdf_bytes=pdf_bytes,
        overall=overall,
        group_stats=group_stats,
        filename=upload.filename or "upload.csv",
        session_id=session_id,
        is_public=session_obj.is_public,
    )

    from models import Analysis

    stored = db.session.get(Analysis, new_analysis_id)
    return created(serialize_analysis(stored))