Skip to content

Commit

Permalink
More efficient working with headers (#289)
Browse files Browse the repository at this point in the history
Co-authored-by: Andriy Plokhotnyuk <[email protected]>
  • Loading branch information
plokhotnyuk and Andriy Plokhotnyuk authored Aug 8, 2023
1 parent 493eef1 commit 27d9bc3
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 208 deletions.
108 changes: 53 additions & 55 deletions core/src/main/scala/sttp/model/MediaType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package sttp.model
import sttp.model.ContentTypeRange.Wildcard
import sttp.model.internal.Rfc2616._
import sttp.model.internal.Validate._
import sttp.model.internal.{Patterns, Validate}
import sttp.model.internal.Patterns

import java.nio.charset.Charset

Expand All @@ -17,20 +17,13 @@ case class MediaType(
def charset(c: String): MediaType = copy(charset = Some(c))
def noCharset: MediaType = copy(charset = None)

def matches(range: ContentTypeRange): Boolean = {
def charsetMatches: Boolean =
if (range.charset == Wildcard) true
// #2994 from tapir: when the media type doesn't define a charset, it shouldn't be taken into account in the matching logic
else charset.isEmpty || charset.map(_.toLowerCase).contains(range.charset.toLowerCase)

(range match {
case ContentTypeRange(Wildcard, _, _) => true
case ContentTypeRange(mainType, Wildcard, _) => this.mainType.equalsIgnoreCase(mainType)
case ContentTypeRange(mainType, subType, _) =>
this.mainType.equalsIgnoreCase(mainType) && this.subType.equalsIgnoreCase(subType)
case null => false
}) && charsetMatches
}
/** Checks whether this media type is covered by the given content type range.
  *
  * A `null` range never matches. The main type and sub type are compared case-insensitively, and a
  * `Wildcard` in the range accepts any value at that position (a wildcard main type accepts everything).
  *
  * #2994 from tapir: when the media type doesn't define a charset, it shouldn't be taken into account in the
  * matching logic, so a missing charset matches every range charset.
  */
def matches(range: ContentTypeRange): Boolean = {
  // Local defs are evaluated lazily, so `range` is only dereferenced after the null check below.
  def typeMatches: Boolean =
    range.mainType == Wildcard ||
      (mainType.equalsIgnoreCase(range.mainType) &&
        (range.subType == Wildcard || subType.equalsIgnoreCase(range.subType)))

  def charsetMatches: Boolean =
    range.charset == Wildcard || charset.forall(_.equalsIgnoreCase(range.charset))

  range != null && typeMatches && charsetMatches
}

def isApplication: Boolean = mainType.equalsIgnoreCase("application")
def isAudio: Boolean = mainType.equalsIgnoreCase("audio")
Expand All @@ -43,10 +36,22 @@ case class MediaType(
def isExample: Boolean = mainType.equalsIgnoreCase("example")
def isModel: Boolean = mainType.equalsIgnoreCase("model")

override def toString: String = s"$mainType/$subType" + charset.fold("")(c => s"; charset=$c") +
otherParameters.foldLeft("") { case (s, (p, v)) => if (p == "charset") s else s"$s; $p=$v" }
/** Renders the media type as `mainType/subType`, followed by `; charset=...` when a charset is set, and
  * then every entry of `otherParameters` as `; name=value`. An entry of `otherParameters` named exactly
  * `charset` is skipped, so the charset is never emitted twice.
  */
override def toString: String = {
  // Initial capacity of 32 fits "application/json; charset=utf-8" (31 characters) without resizing.
  val out = new java.lang.StringBuilder(32)
  out.append(mainType).append('/').append(subType)
  charset.foreach(cs => out.append("; charset=").append(cs))
  for ((name, value) <- otherParameters if name != "charset")
    out.append("; ").append(name).append('=').append(value)
  out.toString
}

override lazy val hashCode: Int = toString.toLowerCase.hashCode

override def hashCode(): Int = toString.toLowerCase.hashCode
override def equals(that: Any): Boolean =
that match {
case t: AnyRef if this.eq(t) => true
Expand Down Expand Up @@ -77,39 +82,31 @@ object MediaType extends MediaTypes {
subType: String,
charset: Option[String] = None,
parameters: Map[String, String] = Map.empty
): Either[String, MediaType] = {
Validate.all(
Seq(
validateToken("Main type", mainType),
validateToken("Sub type", subType),
charset.flatMap(validateToken("Charset", _))
) ++ parameters.map { case (p, v) => validateToken(p, v) }: _*
)(
apply(mainType, subType, charset, parameters)
)
}
): Either[String, MediaType] =
validateToken("Main type", mainType)
.orElse(validateToken("Sub type", subType))
.orElse(charset.flatMap(validateToken("Charset", _)))
.orElse(parameters.collectFirst {
case (p, v) if validateToken(p, v).isDefined => validateToken(p, v).get
}) match {
case None => Right(apply(mainType, subType, charset, parameters))
case Some(error) => Left(error)
}

// based on https://github.com/square/okhttp/blob/20cd3a0/okhttp/src/main/java/okhttp3/MediaType.kt#L94
def parse(t: String): Either[String, MediaType] = {
val typeSubtype = Patterns.TypeSubtype.matcher(t)
if (!typeSubtype.lookingAt()) {
return Left(s"""No subtype found for: "$t"""")
}

val (mainType, subType) = (typeSubtype.group(1), typeSubtype.group(2), typeSubtype.group(3)) match {
// if there are nulls indicating no main and subtype then we expect a single * (group 3)
// it's invalid according to rfc but is used by `HttpUrlConnection` https://bugs.openjdk.java.net/browse/JDK-8163921
case (null, null, Wildcard) => (Wildcard, Wildcard)
case (mainType, subType, _) => (mainType.toLowerCase, subType.toLowerCase)
}

val parameters = Patterns.parseMediaTypeParameters(t, offset = typeSubtype.end())

parameters match {
case Right(params) =>
Right(MediaType(mainType, subType, params.get("charset"), params.filter { case (p, _) => p != "charset" }))
case Left(error) => Left(error)
}
if (typeSubtype.lookingAt()) {
val (mainType, subType) = (typeSubtype.group(1), typeSubtype.group(2), typeSubtype.group(3)) match {
// if there are nulls indicating no main and subtype then we expect a single * (group 3)
// it's invalid according to rfc but is used by `HttpUrlConnection` https://bugs.openjdk.java.net/browse/JDK-8163921
case (null, null, Wildcard) => (Wildcard, Wildcard)
case (mainType, subType, _) => (mainType.toLowerCase, subType.toLowerCase)
}
Patterns
.parseMediaTypeParameters(t, offset = typeSubtype.end())
.map(params => MediaType(mainType, subType, params.get("charset"), params - "charset"))
} else Left(s"""No subtype found for: "$t"""")
}

def unsafeParse(s: String): MediaType = parse(s).getOrThrow
Expand All @@ -120,16 +117,17 @@ object MediaType extends MediaTypes {
* Content type ranges, sorted in order of preference.
*/
def bestMatch(mediaTypes: Seq[MediaType], ranges: Seq[ContentTypeRange]): Option[MediaType] = {
mediaTypes
.map(mt => mt -> ranges.indexWhere(mt.matches))
.filter({ case (_, i) => i != NotFoundIndex }) // not acceptable
match {
case Nil => None
case mts => Some(mts.minBy({ case (_, i) => i })).map { case (mt, _) => mt }
var minMt: MediaType = null
var minIndex = Int.MaxValue
mediaTypes.foreach { mt =>
val index = ranges.indexWhere(mt.matches)
if (index >= 0 && minIndex > index) {
minIndex = index
minMt = mt
}
}
Option(minMt)
}

private val NotFoundIndex = -1
}

// https://www.iana.org/assignments/media-types/media-types.xhtml
Expand Down
34 changes: 11 additions & 23 deletions core/src/main/scala/sttp/model/headers/AcceptEncoding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ case class AcceptEncoding(encodings: List[WeightedEncoding]) {
object AcceptEncoding {

case class WeightedEncoding(encoding: String, weight: Option[BigDecimal]) {
override def toString: String = s"$encoding${weight.map(w => s";q=$w").getOrElse("")}"
override def toString: String = weight.fold(encoding)(w => s"$encoding;q=$w")
}

def parse(str: String): Either[String, AcceptEncoding] = {
val encodings = processString(str, List.empty)
val encodings = processString(str)
if (encodings.isEmpty) Left(s"No encodings found in: $str")
else {
@tailrec
Expand All @@ -34,40 +34,28 @@ object AcceptEncoding {
}
}

@tailrec
private def processString(str: String, acc: List[WeightedEncoding]): List[WeightedEncoding] = {
str.trim.split(",").toList match {
case x :: tail if x.nonEmpty =>
val range = parsSingleEncoding(x)
processString(tail.mkString(","), range :: acc)
case Nil => List(parsSingleEncoding(str))
case _ => acc
}
}
/** Splits the trimmed header value on commas and parses every trimmed piece with `parsSingleEncoding`.
  *
  * The resulting list is in reverse order of appearance in `str` (each parsed element is prepended).
  * TODO: do we really need the reversed order here?
  */
private def processString(str: String): List[WeightedEncoding] =
  str.trim.split(",").foldLeft(List.empty[WeightedEncoding]) { (parsed, piece) =>
    parsSingleEncoding(piece.trim) :: parsed
  }

/** Parses a single `encoding` or `encoding;name=value` element of an Accept-Encoding header.
  *
  * Malformed input is reported by returning a `WeightedEncoding` with an empty encoding name, which
  * `validate` turns into an error. This covers: more than one `;`, a parameter that is not of the
  * form `name=value`, and a weight that is not a valid decimal number.
  *
  * @param s
  *   a single, already trimmed, header element
  * @return
  *   the parsed encoding with its optional weight, or the empty-encoding marker for malformed input
  */
private def parsSingleEncoding(s: String): WeightedEncoding = {
  // Marker for malformed input; rejected later by `validate` because the encoding name is empty.
  def invalid: WeightedEncoding = WeightedEncoding("", None)

  s.split(";") match {
    case Array(algorithm) => WeightedEncoding(algorithm, None)
    case Array(algorithm, weight) =>
      weight.split("=") match {
        case Array(_, value) =>
          // BigDecimal(String) throws NumberFormatException for non-numeric input (e.g. "q=abc").
          // Treat it as malformed instead of letting the exception escape from `parse`.
          try WeightedEncoding(algorithm, Some(BigDecimal(value)))
          catch { case _: NumberFormatException => invalid }
        case _ => invalid
      }
    case _ => invalid
  }
}

private def validate(acceptEncoding: WeightedEncoding, original: => String): Either[String, WeightedEncoding] = {
if (acceptEncoding.encoding.isEmpty) {
Left(s"Invalid empty encoding in: $original")
} else {
private def validate(acceptEncoding: WeightedEncoding, original: => String): Either[String, WeightedEncoding] =
if (acceptEncoding.encoding.isEmpty) Left(s"Invalid empty encoding in: $original")
else
acceptEncoding.weight match {
case Some(value) if value < 0 || value > 1 =>
Left(s"Invalid weight, expected a number between 0 and 1, but got: $value in $original.")
case _ => Right(acceptEncoding)
}
}
}

def unsafeParse(s: String): AcceptEncoding = parse(s).getOrThrow

Expand All @@ -76,6 +64,6 @@ object AcceptEncoding {

def safeApply(encoding: String, weight: Option[BigDecimal]): Either[String, AcceptEncoding] = {
val encodingObject = WeightedEncoding(encoding, weight)
validate(encodingObject, encodingObject.toString).right.map(e => AcceptEncoding(List(e)))
validate(encodingObject, encodingObject.toString).map(e => AcceptEncoding(List(e)))
}
}
100 changes: 63 additions & 37 deletions core/src/main/scala/sttp/model/headers/Accepts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,68 +27,94 @@ object Accepts {
charsets: Seq[(String, Float)]
): Seq[ContentTypeRange] = {
(mediaTypes, charsets) match {
case (Nil, Nil) => Seq(AnyRange)
case (Nil, Nil) => AnyRange :: Nil
case (Nil, (ch, _) :: Nil) => ContentTypeRange(Wildcard, Wildcard, ch) :: Nil
case ((mt, _) :: Nil, Nil) => ContentTypeRange(mt.mainType, mt.subType, Wildcard) :: Nil
case (Nil, chs) =>
chs.sortBy({ case (_, q) => -q }).map { case (ch, _) => ContentTypeRange(Wildcard, Wildcard, ch) }
case (mts, Nil) =>
mts.sortBy({ case (_, q) => -q }).map { case (mt, _) => ContentTypeRange(mt.mainType, mt.subType, Wildcard) }
case (mts, chs) =>
val merged = mts.flatMap { case (mt, mtQ) =>
mts.flatMap { case (mt, mtQ) =>
// if Accept-Charset is defined then any other charset specified in Accept header is not acceptable
chs.map { case (ch, chQ) => (mt, ch) -> math.min(mtQ, chQ) }
} match {
case ((mt, ch), _) :: Nil => ContentTypeRange(mt.mainType, mt.subType, ch) :: Nil
case merged =>
merged.sortBy({ case (_, q) => -q }).map { case ((mt, ch), _) =>
ContentTypeRange(mt.mainType, mt.subType, ch)
}
}
merged.sortBy({ case (_, q) => -q }).map { case ((mt, ch), _) => ContentTypeRange(mt.mainType, mt.subType, ch) }
}
}

private def parseAcceptHeader(headers: Seq[Header]): Either[String, Seq[(MediaType, Float)]] = {
extractEntries(headers, HeaderNames.Accept)
.map(entry => MediaType.parse(entry).right.flatMap(mt => qValue(mt).right.map(mt -> _)))
.partition(_.isLeft) match {
case (Nil, mts) => Right(mts collect { case Right(mtWithQ) => mtWithQ })
case (errors, _) => Left(errors collect { case Left(msg) => msg } mkString "\n")
val errors = new java.lang.StringBuilder()
val mts = List.newBuilder[(MediaType, Float)]
extractEntries(headers, HeaderNames.Accept).foreach { entry =>
MediaType.parse(entry).flatMap(mt => qValue(mt).map(mt -> _)) match {
case Right(mt) =>
mts += mt
case Left(error) =>
if (errors.length != 0) errors.append('\n')
else ()
errors.append(error)
}
}
if (errors.length == 0) Right(mts.result())
else Left(errors.toString)
}

private def parseAcceptCharsetHeader(headers: Seq[Header]): Either[String, Seq[(String, Float)]] =
extractEntries(headers, HeaderNames.AcceptCharset)
.map(parseAcceptCharsetEntry)
.partition(_.isLeft) match {
case (Nil, chs) => Right(chs collect { case Right(ch) => ch })
case (errors, _) => Left(errors collect { case Left(msg) => msg } mkString "\n")
/** Parses all entries of the Accept-Charset headers into `(charset, q)` pairs.
  *
  * Every entry is parsed with `parseAcceptCharsetEntry`. If at least one entry fails, the result is a
  * `Left` with all error messages joined by newlines; otherwise a `Right` with the parsed pairs in the
  * order in which the entries were extracted.
  */
private def parseAcceptCharsetHeader(headers: Seq[Header]): Either[String, Seq[(String, Float)]] = {
  val parsed = List.newBuilder[(String, Float)]
  val failures = new java.lang.StringBuilder()
  for (entry <- extractEntries(headers, HeaderNames.AcceptCharset))
    parseAcceptCharsetEntry(entry) match {
      case Left(message) =>
        // Messages are newline-separated; the separator is only added once something was collected.
        if (failures.length != 0) failures.append('\n')
        failures.append(message)
      case Right(charsetWithQ) =>
        parsed += charsetWithQ
    }
  if (failures.length != 0) Left(failures.toString)
  else Right(parsed.result())
}

private def parseAcceptCharsetEntry(entry: String): Either[String, (String, Float)] = {
val name = Patterns.Type.matcher(entry)
if (!name.lookingAt()) {
Left(s"""No charset found for: "$entry"""")
} else {
Patterns.parseMediaTypeParameters(entry, offset = name.end()) match {
case Right(params) =>
qValueFrom(params) match {
case Right(q) => Right(name.group(1).toLowerCase -> q)
case Left(error) => Left(error)
}
case Left(error) => Left(error)
}
if (name.lookingAt()) {
Patterns
.parseMediaTypeParameters(entry, offset = name.end())
.flatMap(qValueFrom(_).map(name.group(1).toLowerCase -> _))
} else Left(s"""No charset found for: "$entry"""")
}

/** Collects the comma-separated, trimmed values of all headers for which `h.is(name)` holds.
  *
  * Entries keep the order of the headers and, within a header, the order of its comma-separated parts.
  */
private def extractEntries(headers: Seq[Header], name: String): Seq[String] = {
  val collected = List.newBuilder[String]
  for (header <- headers if header.is(name))
    collected ++= trimInPlace(header.value.split(","))
  collected.result()
}

private def extractEntries(headers: Seq[Header], name: String): Seq[String] =
headers
.filter(_.is(name))
.flatMap(_.value.split(","))
.map(_.replaceAll(Patterns.WhiteSpaces, ""))
/** Replaces every element of `ss` with its trimmed version.
  *
  * The array is mutated in place (no copy is made) and the very same instance is returned, so the call
  * can be used inline.
  */
private def trimInPlace(ss: Array[String]): Array[String] = {
  ss.indices.foreach(idx => ss(idx) = ss(idx).trim)
  ss
}

private def qValue(mt: MediaType): Either[String, Float] = qValueFrom(mt.otherParameters)

private def qValueFrom(parameters: Map[String, String]): Either[String, Float] =
parameters.get("q") collect { case Patterns.QValue(q) => q.toFloat } match {
case Some(value) => Right(value)
case None =>
parameters
.get("q")
.map(q => Left(s"""q must be numeric value between <0, 1> with up to 3 decimal points, provided "$q""""))
.getOrElse(Right(1f))
parameters.get("q") match {
case None => Right(1f)
case Some(q) =>
val qValue = Patterns.QValue.matcher(q)
if (qValue.matches() && qValue.groupCount() == 1) Right(qValue.group(1).toFloat)
else Left(s"""q must be numeric value between <0, 1> with up to 3 decimal points, provided "$q"""")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ object AuthenticationScheme {
val qopValue = params.getOrElse(qop, "")
val qopValueMatch = qopValues.exists(_.equals(qopValue))
if (!containsNonce) Left(s"Missing nonce parameter in: $params")
else if (!containsOpaque) Left("Missing opaque parameter in: $params")
else if (!qopValueMatch) Left("qop value incorrect in: $params")
else if (!containsOpaque) Left(s"Missing opaque parameter in: $params")
else if (!qopValueMatch) Left(s"qop value incorrect in: $params")
else Right(())
}

Expand Down
Loading

0 comments on commit 27d9bc3

Please sign in to comment.