diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..098c66ce --- /dev/null +++ b/Dockerfile @@ -0,0 +1,24 @@ +# 1. Base image: the default (non-slim) python:3.10.8 tag; adjust the tag to suit the project +FROM python:3.10.8 + +# 2. Set the working directory inside the container +WORKDIR /app + +# 3. Copy requirements.txt first so the dependency layer can be cached +# (requirements.txt is taken from the root of the build context) +COPY requirements.txt . + +# 4. Install dependencies (--no-cache-dir keeps pip's download cache out of the image) +RUN pip install --no-cache-dir -r requirements.txt + +# 5. Copy the rest of the source code +# (everything in the build context that .dockerignore does not exclude is copied to /app) +COPY . . + +# 6. Declare the exposed port (documentation and tooling hint) +EXPOSE 8000 + +# 7. Start FastAPI with Uvicorn +# --host 0.0.0.0: accept connections from outside the container (required) +# --port 8000: same port as EXPOSE above. NOTE(review): confirm it matches the Kubernetes containerPort +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..600d2d33 --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +.vscode \ No newline at end of file diff --git a/main.py b/main.py index 798bcc9e..67403784 100644 --- a/main.py +++ b/main.py @@ -6,12 +6,11 @@ - 외부 트렌드 요약 (Kanana) - Qwen 기반 개인 소비 리포트 """ - from typing import List, Optional, Dict, Any - -from fastapi import FastAPI, UploadFile, File, Form, HTTPException +from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel +from sqlalchemy.orm import Session # 1) 기존 모듈 import # OCR: ocr/main.py 에 있는 로직 재사용 @@ -26,19 +25,11 @@ suggest_category, ) -# 소비 통계/툴: tool.py -from tool import ( - get_total_spend, - get_top_merchants, - get_trend, -) # 트렌드 요약: trend_summary.py from trend_summary import run as run_trend_summary, TrendSummary # Qwen 리포트: report/ 폴더 -from report.schemas import ReportRequest, ReportResponse -from report.db import get_transactions_from_db try: from report.qwen_model import generate_spending_report QWEN_AVAILABLE = True @@ -46,10 +37,14 @@ QWEN_AVAILABLE = False def 
generate_spending_report(*args, **kwargs): return "(로컬 개발 환경에서는 Qwen 모델을 사용할 수 없습니다.)" + +from report import models +from report import schemas +from report.database import engine, get_db +models.Base.metadata.create_all(bind=engine) # 2) FastAPI 앱 공통 설정 - app = FastAPI( title="OpenWallet Unified API", version="1.0.0", @@ -108,93 +103,6 @@ async def api_ocr_receipt( except Exception as e: raise HTTPException(500, f"OCR 처리 중 오류: {e}") -# 소비 통계 / 집계 API (tool.py 래핑) - -class TotalSpendRequest(BaseModel): - user_id: str - start_date: str - end_date: str - category_ids: Optional[List[int]] = None - merchant_ids: Optional[List[int]] = None - - -class TotalSpendResponse(BaseModel): - total: int - currency: str = "KRW" - - -@app.post("/stats/total-spend", response_model=TotalSpendResponse) -def api_total_spend(req: TotalSpendRequest): - data = get_total_spend( - user_id=req.user_id, - start_date=req.start_date, - end_date=req.end_date, - category_ids=req.category_ids, - merchant_ids=req.merchant_ids, - ) - return TotalSpendResponse(**data) - - -class TopMerchantsRequest(BaseModel): - user_id: str - start_date: str - end_date: str - limit: int = 5 - category_ids: Optional[List[int]] = None - - -class MerchantSummary(BaseModel): - merchant_id: int - merchant_name: str - amount: int - - -class TopMerchantsResponse(BaseModel): - top_merchants: List[MerchantSummary] - currency: str = "KRW" - - -@app.post("/stats/top-merchants", response_model=TopMerchantsResponse) -def api_top_merchants(req: TopMerchantsRequest): - data = get_top_merchants( - user_id=req.user_id, - start_date=req.start_date, - end_date=req.end_date, - limit=req.limit, - category_ids=req.category_ids, - ) - # data 형태는 {"top_merchants": [...], "currency": "KRW"} - return TopMerchantsResponse(**data) - - -class TrendRequest(BaseModel): - user_id: str - period: str = "monthly" # "monthly" | "weekly" - months: int = 6 - category_ids: Optional[List[int]] = None - - -class TrendPoint(BaseModel): - period: str - 
amount: int - - -class TrendResponse(BaseModel): - series: List[TrendPoint] - currency: str = "KRW" - - -@app.post("/stats/trend", response_model=TrendResponse) -def api_trend(req: TrendRequest): - data = get_trend( - user_id=req.user_id, - period=req.period, - months=req.months, - category_ids=req.category_ids, - ) - # {"series": [{"period": "...", "amount": ...}, ...], "currency": "KRW"} - return TrendResponse(**data) - # 3. 외부 트렌드 요약 API (Kanana + SQLite) class TrendSummaryRequest(BaseModel): @@ -246,48 +154,69 @@ def api_trend_summary(req: TrendSummaryRequest): # 4. Qwen 기반 개인 소비 리포트 API # (기존 report/main.py 로직 그대로) -@app.post("/report", response_model=ReportResponse) -def api_report(request: ReportRequest): - """ - - Request: ReportRequest (user_id, start_date, end_date, question) - - DB에서 거래 내역 조회 후 Qwen으로 리포트 생성 - """ - # 1) DB에서 거래 내역 가져오기 - transactions = get_transactions_from_db( - user_id=request.user_id, - start_date=request.start_date, - end_date=request.end_date, +@app.post("/report", response_model=schemas.ReportResponse) +def create_report(request: schemas.ReportRequest, db: Session = Depends(get_db)): + # 1. Load the stored models.Expense rows dated within [start_date, end_date], both ends inclusive. + # NOTE(review): the filter is by date only, not by user - confirm whether the query must be scoped to the requester. + expenses_query = db.query(models.Expense).filter( + models.Expense.date >= request.start_date, + models.Expense.date <= request.end_date ) + expenses = expenses_query.all() - if not transactions: + if not expenses: raise HTTPException( - status_code=404, - detail="해당 조건에 해당하는 거래 내역이 없습니다.", + status_code=404, + detail="해당 기간에 조회된 지출 데이터가 없습니다." 
) - - # 2) Qwen 모델로 리포트 생성 - if not QWEN_AVAILABLE: - return ReportResponse( - report="(로컬 개발 환경: Qwen 모델 비활성화됨)", - user_id=request.user_id, - start_date=request.start_date, - end_date=request.end_date, - transaction_count=len(transactions), - ) - - report_text = generate_spending_report( - transactions=transactions, - user_question=request.question, + total_amount = 0 + category_summary = {} # e.g. {"FOOD": 50000, "TRANSPORT": 30000} + # 2. Convert the ORM rows into a list of dicts (model input) + # while accumulating the overall and per-category totals + transaction_list = [] + for exp in expenses: + # Overall total + total_amount += exp.price + + # Per-category total + cat_name = exp.category + if cat_name not in category_summary: + category_summary[cat_name] = 0 + category_summary[cat_name] += exp.price + + # Cap the itemised list handed to the model (currently the first 30 rows) + if len(transaction_list) < 30: + transaction_list.append({ + "date": str(exp.date), + "merchant": exp.title, + "amount": exp.price, + "category": exp.category + }) + + # 3. Build the summary payload (totals plus the capped sample). + # NOTE(review): summary_text is never passed on - only transaction_list reaches generate_spending_report below. + summary_text = { + "total_spent": total_amount, + "category_breakdown": category_summary, + "recent_transactions_sample": transaction_list # sample only + } + + # 4. Model inference: generate the report + try: + report_text = generate_spending_report( + transactions=transaction_list, + user_question=request.question ) + except Exception as e: + print(f"LLM Generation Error: {e}") + raise HTTPException(status_code=500, detail=f"리포트 생성 중 오류가 발생했습니다: {str(e)}") from e - - # 3) 응답 구성 - return ReportResponse( + # 5. Return the result (transaction_count is the number of rows found, not the capped sample size) + return schemas.ReportResponse( report=report_text, - user_id=request.user_id, start_date=request.start_date, end_date=request.end_date, - transaction_count=len(transactions), + transaction_count=len(expenses) ) # 5. 
Health Check diff --git a/report/models.py b/report/models.py index fe3f1511..263b158f 100644 --- a/report/models.py +++ b/report/models.py @@ -1,7 +1,22 @@ # 작성일 : 25/11/30 + from sqlalchemy import Column, String, Integer, Date, Text -from database import Base import uuid +import sys +import os + +# 1. Directory that contains this file (models.py). +current_dir = os.path.dirname(os.path.abspath(__file__)) + +# 2. Parent of the report/ directory (the OpenWallet_AI project root). +root_dir = os.path.dirname(current_dir) + +# 3. Append the project root to sys.path if it is missing. NOTE(review): the relative import below does not rely on this - confirm it is still needed. +if root_dir not in sys.path: + sys.path.append(root_dir) + + +from .database import Base # UUID 생성을 위한 함수 def generate_uuid(): diff --git a/requirements.txt b/requirements.txt index 290303a2..b5cd5316 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,8 @@ openai==1.33.0 # === Utils === tqdm>=4.66.4 protobuf>=4.25.3 + +# === Required by report === +pymysql +cryptography +huggingface_hub[hf_xet] \ No newline at end of file diff --git a/tool.py b/tool.py deleted file mode 100644 index a9f1bf1c..00000000 --- a/tool.py +++ /dev/null @@ -1,69 +0,0 @@ -from datetime import datetime -from collections import defaultdict -from typing import List, Optional, Dict - - -# --- Mock 데이터: (실서비스에서는 DB로 대체) -RECEIPTS = [ - # user_id, purchased_at(YYYY-MM-DD), merchant_id, category_id, total_amount - ("u_123", "2025-11-01", 10, 21, 4500), # 카페 - ("u_123", "2025-11-03", 10, 21, 6100), # 카페 - ("u_123", "2025-11-10", 11, 7, 38000), # 마트 - ("u_123", "2025-10-11", 10, 21, 5200), -] -MERCHANTS = {10: "스타카페", 11: "하이퍼마트"} -CATEGORIES = {21: "카페/커피", 7: "식료품"} - - -def _in_range(d: str, start: str, end: str) -> bool: - return (start <= d) and (d < end) - - -def get_total_spend(user_id: str, start_date: str, end_date: str, - category_ids: Optional[List[int]] = None, - merchant_ids: Optional[List[int]] = None) -> Dict: - - - s = 0 - for uid, d, mid, cid, amt in RECEIPTS: - if uid != user_id: - continue - if not _in_range(d, start_date, end_date): - continue - if 
category_ids and cid not in category_ids: - continue - if merchant_ids and mid not in merchant_ids: - continue - s += amt - return {"total": s, "currency": "KRW"} - - -def get_top_merchants(user_id: str, start_date: str, end_date: str, - limit: int = 5, category_ids: Optional[List[int]] = None) -> Dict: - agg = defaultdict(int) - for uid, d, mid, cid, amt in RECEIPTS: - if uid != user_id or not _in_range(d, start_date, end_date): - continue - if category_ids and cid not in category_ids: - continue - agg[mid] += amt - ranked = sorted(agg.items(), key=lambda x: x[1], reverse=True)[:limit] - result = [{"merchant_id": mid, "merchant_name": MERCHANTS.get( - mid, str(mid)), "amount": amt} for mid, amt in ranked] - return {"top_merchants": result, "currency": "KRW"} - - -def get_trend(user_id: str, period: str = "monthly", months: int = 6, - category_ids: Optional[List[int]] = None) -> Dict: - # 단순 mock: YYYY-MM 단위 합계 - box = defaultdict(int) - for uid, d, mid, cid, amt in RECEIPTS: - if uid != user_id: - continue - if category_ids and cid not in category_ids: - continue - ym = d[:7] - box[ym] += amt - series = sorted([{"period": k, "amount": v} - for k, v in box.items()], key=lambda x: x["period"]) - return {"series": series, "currency": "KRW"}