Skip to content

Commit

Permalink
Merge branch 'main' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
egg528 authored Apr 26, 2024
2 parents 39c9f54 + 72425c5 commit b0b9112
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.raemian.api.support

import jakarta.servlet.http.Cookie
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.util.SerializationUtils
import java.util.Base64
import java.util.Optional

object CookieUtils {
fun getCookie(request: HttpServletRequest, name: String): Optional<Cookie> {
val cookies = request.cookies
if (cookies != null && cookies.isNotEmpty()) {
for (cookie in cookies) {
if (cookie.name == name) {
return Optional.of(cookie)
}
}
}
return Optional.empty()
}

fun addCookie(response: HttpServletResponse, name: String, value: String, maxAge: Int) {
val cookie = Cookie(name, value)
cookie.path = "/"
cookie.isHttpOnly = true
cookie.maxAge = maxAge
response.addCookie(cookie)
}

fun deleteCookie(request: HttpServletRequest, response: HttpServletResponse, name: String) {
val cookies = request.cookies
if (cookies != null && cookies.isNotEmpty()) {
for (cookie in cookies) {
if (cookie.name == name) {
cookie.value = ""
cookie.path = "/"
cookie.maxAge = 0
response.addCookie(cookie)
}
}
}
}

fun serialize(obj: Any?): String {
return Base64.getUrlEncoder()
.encodeToString(
SerializationUtils.serialize(obj),
)
}

fun <T> deserialize(cookie: Cookie, cls: Class<T>): T {
return cls.cast(SerializationUtils.deserialize(Base64.getUrlDecoder().decode(cookie.value)))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package io.raemian.api.support

import jakarta.servlet.http.Cookie
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
import org.springframework.stereotype.Component
import java.time.Duration

@Component
class HttpCookieOAuth2AuthorizationRequestRepository() : AuthorizationRequestRepository<OAuth2AuthorizationRequest> {

private val AUTHORIZATION_REQUEST_COOKIE_NAME = "oauth2_auth_request"
private val EXPIRE_SECONDS: Int = Duration.ofSeconds(180).toMillis().toInt()

override fun loadAuthorizationRequest(request: HttpServletRequest): OAuth2AuthorizationRequest? {
val state = this.getStateParameter(request) ?: return null

val authorizationRequest: OAuth2AuthorizationRequest? =
CookieUtils.getCookie(request, AUTHORIZATION_REQUEST_COOKIE_NAME)
.map { cookie: Cookie ->
CookieUtils.deserialize(cookie, OAuth2AuthorizationRequest::class.java)
}.orElse(null)

return if (authorizationRequest != null && state == authorizationRequest.state) {
authorizationRequest
} else {
null
}
}
override fun saveAuthorizationRequest(
authorizationRequest: OAuth2AuthorizationRequest?,
request: HttpServletRequest,
response: HttpServletResponse,
) {
if (authorizationRequest == null) {
removeAuthorizationRequest(request, response)
return
}

CookieUtils.addCookie(
response,
AUTHORIZATION_REQUEST_COOKIE_NAME,
CookieUtils.serialize(authorizationRequest),
EXPIRE_SECONDS,
)
}

override fun removeAuthorizationRequest(
request: HttpServletRequest,
response: HttpServletResponse,
): OAuth2AuthorizationRequest? {
val authorizationRequest: OAuth2AuthorizationRequest? = this.loadAuthorizationRequest(request)

if (authorizationRequest != null) {
removeAuthorizationRequestCookies(request, response)
}

return authorizationRequest
}

private fun removeAuthorizationRequestCookies(request: HttpServletRequest, response: HttpServletResponse) {
CookieUtils.deleteCookie(request, response, AUTHORIZATION_REQUEST_COOKIE_NAME)
}

private fun getStateParameter(request: HttpServletRequest): String? = request.getParameter("state")
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class UserService(
return UserResult.of(userRepository.save(updated))
}


fun update(id: Long, nickname: String, birth: LocalDate?, username: String, image: String): UserResult {
val user = userRepository.getById(id)

Expand Down

0 comments on commit b0b9112

Please sign in to comment.