Skip to content

Commit 778050c

Browse files
authored
Merge pull request #11 from MapChillE/fix/UBLE-80-elasticsearch
[Fix] elasticsearch 연동 및 location 추가
2 parents 425ab41 + 97ba145 commit 778050c

File tree

5 files changed

+141
-29
lines changed

5 files changed

+141
-29
lines changed

app/api/recommend.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
from fastapi import APIRouter, Depends, HTTPException, Query
2-
from sqlalchemy.orm import Session
2+
from sqlalchemy.orm import Session, joinedload
33
from sentence_transformers import SentenceTransformer
4-
from sqlalchemy import text
4+
from sqlalchemy import text, func
55
from app.database.connection import get_db
6-
from app.models import Store
6+
from app.models import Store, Brand
77
from app.services.recommend_service import HybridRecommender
88
from geoalchemy2.functions import ST_DWithin, ST_SetSRID, ST_MakePoint, ST_Distance
9+
from geoalchemy2 import Geometry
10+
from geoalchemy2.shape import to_shape
911
import logging
1012
from app.services.collect_user_data import collect_user_data
1113
from collections import defaultdict
1214
from app.database.redis_client import r
1315
import json
16+
from app.database.es import es
17+
1418

1519
router = APIRouter()
1620
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
@@ -23,6 +27,12 @@ def get_min_rank(benefits: list) -> str:
2327
return b.rank
2428
return "NONE"
2529

30+
def extract_lat_lng(store):
31+
if store.location is None:
32+
return None, None
33+
point = to_shape(store.location)
34+
return point.y, point.x
35+
2636
@router.get("/recommend")
2737
def recommend(
2838
user_id:int,
@@ -33,7 +43,7 @@ def recommend(
3343
):
3444

3545
# 1. 사용자 정보 수집
36-
categories, histories, bookmarks, clicks, searches = collect_user_data(user_id, db)
46+
categories, histories, bookmarks, clicks, searches = collect_user_data(user_id, db, es)
3747

3848
if not (categories or histories or bookmarks or clicks or searches):
3949
raise HTTPException(status_code=404, detail="사용자 정보가 부족합니다.")
@@ -111,7 +121,7 @@ def hybrid_recommend(
111121
logger.error(f"Redis 캐시 확인 중 오류: {e}")
112122

113123
# 1. 사용자 텍스트 정보 수집
114-
categories, histories, bookmarks, clicks, searches = collect_user_data(user_id, db)
124+
categories, histories, bookmarks, clicks, searches = collect_user_data(user_id, db, es)
115125

116126
if not (categories or histories or bookmarks or clicks or searches):
117127
raise HTTPException(status_code=404, detail="사용자 정보가 부족합니다.")
@@ -126,17 +136,20 @@ def hybrid_recommend(
126136
logger.debug(f"Recommendation results for user {user_id}: {results}")
127137

128138
# 4. 위치 기반 필터링: 추천 브랜드 매장 중 반경 km 이내
129-
nearby_stores = db.query(Store).filter(
130-
Store.brand_id.in_(recommended_brand_ids),
131-
ST_DWithin(Store.location, ST_SetSRID(ST_MakePoint(lng, lat), 4326), radius_km * 1000)
139+
store_query = db.query(Store).options(
140+
joinedload(Store.brand).joinedload(Brand.category),
141+
joinedload(Store.brand).joinedload(Brand.benefits)
142+
).filter(
143+
Store.brand_id.in_(recommended_brand_ids),
144+
func.ST_DWithin(Store.location, func.ST_SetSRID(func.ST_MakePoint(lng, lat), 4326), radius_km * 1000)
132145
).order_by(
133-
ST_Distance(Store.location, ST_SetSRID(ST_MakePoint(lng, lat), 4326))
146+
func.ST_Distance(Store.location, func.ST_SetSRID(func.ST_MakePoint(lng, lat), 4326))
134147
).all()
135148

136149
# 매장 중 하나씩 결과 연결
137150
store_map = {}
138151
brand_store_map = defaultdict(list)
139-
for store in nearby_stores:
152+
for store in store_query:
140153
brand_store_map[store.brand_id].append(store)
141154
store_map = {bid: stores[0] for bid, stores in brand_store_map.items()}
142155

@@ -147,10 +160,13 @@ def hybrid_recommend(
147160
continue
148161
store = store_map[brand_id]
149162
brand = store.brand
163+
lat, lng = extract_lat_lng(store)
150164

151165
item = {
152166
"storeId": store.id,
153167
"storeName": store.name,
168+
"latitude": lat,
169+
"longitude": lng,
154170
"category": brand.category.name if brand.category else None,
155171
"description": brand.description,
156172
"isVIPcock": brand.rank_type in ("VIP", "VIP_NORMAL"),

app/database/es.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from elasticsearch import Elasticsearch
2+
from dotenv import load_dotenv
3+
import os
4+
import logging
5+
6+
logger = logging.getLogger(__name__)
7+
8+
load_dotenv()
9+
10+
ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL")
11+
ES_ID = os.getenv("ES_ID")
12+
ES_PW = os.getenv("ES_PW")
13+
14+
if not all([ELASTICSEARCH_URL, ES_ID, ES_PW]):
15+
raise ValueError("필수 Elasticsearch 환경 변수가 설정되지 않았습니다.")
16+
17+
es = Elasticsearch(
18+
ELASTICSEARCH_URL,
19+
basic_auth=(ES_ID, ES_PW),
20+
verify_certs=True
21+
)
22+
23+
#연결 확인
24+
try:
25+
if not es.ping():
26+
raise ConnectionError("Elasticsearch 서버에 연결할 수 없습니다.")
27+
logger.info("Elasticsearch 연결 성공")
28+
except Exception as e:
29+
logger.error(f"Elasticsearch 연결 실패: {e}")
30+
raise

app/services/collect_user_data.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
from sqlalchemy.orm import Session
22
from sqlalchemy import text
3+
from elasticsearch import Elasticsearch
4+
from elasticsearch.helpers import scan
5+
import logging
36

4-
def collect_user_data(user_id: int, db: Session) -> tuple[list, list, list, list, list]:
5-
"""Collect user-related text data from various sources."""
7+
logger = logging.getLogger(__name__)
8+
9+
def collect_user_data(user_id: int, db: Session, es: Elasticsearch) -> tuple[list, list, list, list, list]:
10+
11+
#RDB에서 가져오는 데이터
612
categories = db.execute(text("""
713
SELECT c.name FROM user_category uc
814
JOIN category c ON uc.category_id = c.id
@@ -21,14 +27,35 @@ def collect_user_data(user_id: int, db: Session) -> tuple[list, list, list, list
2127
WHERE bm.user_id = :user_id
2228
"""), {"user_id": user_id}).scalars().all()
2329

24-
clicks = db.execute(text("""
25-
SELECT s.name FROM store_click_log cl
26-
JOIN store s ON cl.store_id = s.id
27-
WHERE cl.user_id = :user_id
28-
"""), {"user_id": user_id}).scalars().all()
30+
# es에서 가져오는 클릭 로그
31+
clicks = []
32+
try:
33+
for doc in scan(es, index="store-click-log", query={
34+
"query": {
35+
"term": {
36+
"userId": user_id
37+
}
38+
}
39+
}):
40+
store_name = doc["_source"].get("storeName")
41+
if store_name:
42+
clicks.append(store_name)
43+
except Exception as e:
44+
logger.error(f"클릭 로그 조회 실패 (user_id: {user_id}: {e})")
2945

30-
searches = db.execute(text("""
31-
SELECT keyword FROM search_log WHERE user_id = :user_id
32-
"""), {"user_id": user_id}).scalars().all()
33-
46+
# es에서 가져오는 검색 로그
47+
searches = []
48+
try:
49+
for doc in scan(es, index="search-log", query={
50+
"query": {
51+
"term": {
52+
"userId": user_id
53+
}
54+
}
55+
}):
56+
keyword = doc["_source"].get("searchKeyword")
57+
if keyword:
58+
searches.append(keyword)
59+
except Exception as e:
60+
logger.error(f"검색 로그 조회 실패 (user_id: {user_id}: {e})")
3461
return categories, histories, bookmarks, clicks, searches

app/services/recommend_service.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44
from implicit.als import AlternatingLeastSquares
55
from sqlalchemy.orm import Session
66
from app.models import BrandClickLog, StoreClickLog, BrandEmbedding, Store
7+
from elasticsearch import Elasticsearch
8+
from elasticsearch.helpers import scan
9+
from app.database.es import es
10+
import logging
11+
12+
logger = logging.getLogger(__name__)
13+
714

815
class HybridRecommender:
916
def __init__(self):
@@ -12,15 +19,28 @@ def __init__(self):
1219
self.item_factors = None
1320
self.item_id_to_index = {}
1421
self.index_to_item_id = {}
22+
self.user_id_to_code = {}
23+
self.code_to_user_id = {}
24+
self.es = es
25+
26+
# 로그 가져오는 함수
27+
def get_logs_from_es(self, index_name: str):
28+
logs = []
29+
try:
30+
for doc in scan(self.es, index=index_name):
31+
logs.append(doc["_source"])
32+
except Exception as e:
33+
logger.error(f"Elasticsearch 로그 조회 실패 (index: {index_name}): {e}")
34+
return logs
1535

1636
def train_model(self, db: Session):
17-
store_logs = db.query(StoreClickLog).all()
18-
brand_logs = db.query(BrandClickLog).all()
37+
store_logs = self.get_logs_from_es("store-click-log")
38+
brand_logs = self.get_logs_from_es("brand-click-log")
1939

2040
combined_data = []
2141

22-
# Store 클릭 로그 : brand_id로 매핑
23-
store_ids = [log.store_id for log in store_logs]
42+
# Store 클릭 로그 : store_id -> brand_id로 매핑
43+
store_ids = [log["storeId"] for log in store_logs]
2444
stores_with_brands = db.query(Store.id, Store.brand_id).filter(
2545
Store.id.in_(store_ids),
2646
Store.brand_id.isnot(None)
@@ -29,13 +49,13 @@ def train_model(self, db: Session):
2949
store_to_brand = {store_id: brand_id for store_id, brand_id in stores_with_brands}
3050

3151
for log in store_logs:
32-
brand_id = store_to_brand.get(log.store_id)
52+
brand_id = store_to_brand.get(log["storeId"])
3353
if brand_id:
34-
combined_data.append({'user_id': log.user_id, 'brand_id': brand_id})
54+
combined_data.append({'user_id': log["userId"], "brand_id": brand_id})
3555

3656
# Brand 클릭 로그
3757
for log in brand_logs:
38-
combined_data.append({'user_id': log.user_id, 'brand_id': log.brand_id})
58+
combined_data.append({'user_id': log["userId"], "brand_id": log["brandId"]})
3959

4060
if not combined_data:
4161
return

requirements.txt

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,35 @@
11
annotated-types==0.7.0
22
anyio==4.9.0
3+
certifi==2025.7.9
4+
charset-normalizer==3.4.2
5+
click==8.2.1
6+
elastic-transport==8.17.1
7+
elasticsearch==8.12.1
38
fastapi==0.116.1
9+
filelock==3.18.0
10+
fsspec==2025.5.1
411
GeoAlchemy2==0.17.1
512
h11==0.16.0
13+
hf-xet==1.1.5
14+
huggingface-hub==0.33.4
615
idna==3.10
716
implicit==0.7.2
17+
Jinja2==3.1.6
818
joblib==1.5.1
19+
MarkupSafe==3.0.2
920
mpmath==1.3.0
21+
networkx==3.5
1022
numpy==2.3.1
23+
packaging==25.0
1124
pandas==2.3.1
1225
pgvector==0.4.1
1326
pillow==11.3.0
14-
psycopg2-binary==2.9.10
27+
psycopg2==2.9.10
1528
pydantic==2.11.7
1629
pydantic_core==2.33.2
1730
python-dateutil==2.9.0.post0
1831
python-dotenv==1.1.1
32+
pytz==2025.2
1933
PyYAML==6.0.2
2034
redis==6.2.0
2135
regex==2024.11.6
@@ -24,6 +38,9 @@ safetensors==0.5.3
2438
scikit-learn==1.7.0
2539
scipy==1.16.0
2640
sentence-transformers==5.0.0
41+
setuptools==80.9.0
42+
shapely==2.1.1
43+
six==1.17.0
2744
sniffio==1.3.1
2845
SQLAlchemy==2.0.41
2946
starlette==0.47.1
@@ -35,4 +52,6 @@ tqdm==4.67.1
3552
transformers==4.53.2
3653
typing-inspection==0.4.1
3754
typing_extensions==4.14.1
55+
tzdata==2025.2
56+
urllib3==2.5.0
3857
uvicorn==0.35.0

0 commit comments

Comments (0)