Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,13 @@ lazy val core = projectMatrix
"smithy4s.http.HttpUnaryServerRouter.partialFunction"
),
// originating in an Alloy update that removed ProtoCompactOffsetDateTime
ProblemFilters.exclude[MissingClassProblem]("alloy.proto.ProtoCompactOffsetDateTime"),
ProblemFilters.exclude[MissingClassProblem](
"alloy.proto.ProtoCompactOffsetDateTime"
),
// originating in an Alloy update that removed ProtoCompactOffsetDateTime
ProblemFilters.exclude[MissingClassProblem]("alloy.proto.ProtoCompactOffsetDateTime$"),
ProblemFilters.exclude[MissingClassProblem](
"alloy.proto.ProtoCompactOffsetDateTime$"
)
)
)
.jvmPlatform(allJvmScalaVersions, jvmDimSettings)
Expand Down Expand Up @@ -439,7 +443,7 @@ lazy val codegen = projectMatrix
.in(file("modules/codegen"))
.enablePlugins(BuildInfoPlugin)
.dependsOn(protocol)
.jvmPlatform(buildtimejvmScala2Versions, jvmDimSettings)
.jvmPlatform(allJvmScalaVersions, jvmDimSettings)
.settings(
buildInfoKeys := Seq[BuildInfoKey](
version,
Expand All @@ -463,14 +467,32 @@ lazy val codegen = projectMatrix
Dependencies.Circe.core.value,
Dependencies.Circe.parser.value,
Dependencies.Circe.generic.value,
Dependencies.collectionsCompat.value,
"org.scala-lang" % "scala-reflect" % scalaVersion.value,
"io.get-coursier" %% "coursier" % "2.1.24",
Dependencies.Mima.core % Test
Dependencies.collectionsCompat.value
),
libraryDependencies += {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, _)) => Dependencies.coursier_2
case Some((3, _)) => Dependencies.coursier_3
case other => sys.error(s"unsupported scala version $other")
}
},
libraryDependencies ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, _)) =>
Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value)
case _ => Seq.empty
}
},
libraryDependencies ++= munitDeps.value,
scalacOptions := scalacOptions.value
.filterNot(Seq("-Ywarn-value-discard", "-Wvalue-discard").contains),
scalacOptions ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((3, _)) =>
Seq("-Wconf:cat=deprecation:silent")
case _ => Nil
}
},
bloopEnabled := true,
Compile / sourceGenerators += {
sourceManaged
Expand Down
86 changes: 51 additions & 35 deletions modules/codegen/src/smithy4s/codegen/internals/Renderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1821,8 +1821,8 @@ private[internals] class Renderer(compilationUnit: CompilationUnit) { self =>
// NOTE: this match doesn't have exhaustivity checking on Scala 2! (due to the Aux pattern's weird interaction with gADTs)
prim match {
case Primitive.BigDecimal =>
(bd: BigDecimal) => line"scala.math.BigDecimal($bd)"
case Primitive.BigInteger => (bi: BigInt) => line"scala.math.BigInt($bi)"
((bd: BigDecimal) => line"scala.math.BigDecimal($bd)").asInstanceOf[T => Line]
case Primitive.BigInteger => ((bi: BigInt) => line"scala.math.BigInt($bi)").asInstanceOf[T => Line]
case Primitive.Unit => _ => line"()"
case Primitive.Double => t => line"${t.toString}d"
case Primitive.Float => t => line"${t.toString}f"
Expand All @@ -1831,60 +1831,76 @@ private[internals] class Renderer(compilationUnit: CompilationUnit) { self =>
case Primitive.Short => t => line"${t.toString}"
case Primitive.Bool => t => line"${t.toString}"
case Primitive.Uuid => uuid => line"java.util.UUID.fromString(${renderStringLiteral(uuid.toString)})"
case Primitive.String => renderStringLiteral
case Primitive.String => ((str: String) => renderStringLiteral(str)).asInstanceOf[T => Line]
case Primitive.Byte => b => line"${b.toString}"
case Primitive.Blob =>
ba =>
{ (ba: Array[Byte]) =>
val blob = NameRef("smithy4s", "Blob")
if (ba.isEmpty) line"$blob.empty"
else
line"$blob(Array[Byte](${ba.mkString(", ")}))"
}.asInstanceOf[T => Line]
case Primitive.Timestamp =>
ts => line"${NameRef("smithy4s", "Timestamp")}(${ts.getEpochSecond()}L, ${ts.getNano()})"
case Primitive.Document => { (node: Node) =>
node.accept(new NodeVisitor[Line] {
def arrayNode(x: ArrayNode): Line = {
val innerValues = x.getElements().asScala.map(_.accept(this))
line"smithy4s.Document.array(${innerValues.toList.intercalate(Line.comma)})"
}
def booleanNode(x: BooleanNode): Line =
line"smithy4s.Document.fromBoolean(${x.getValue})"
def nullNode(x: NullNode): Line =
line"smithy4s.Document.nullDoc"
def numberNode(x: NumberNode): Line =
line"smithy4s.Document.fromDouble(${x.getValue.doubleValue()}d)"
def objectNode(x: ObjectNode): Line = {
val members = x.getMembers.asScala.map { member =>
val key = renderStringLiteral(member._1.getValue)
val value = member._2.accept(this)
line"$key -> $value"
((ts: java.time.Instant) => line"${NameRef("smithy4s", "Timestamp")}(${ts.getEpochSecond()}L, ${ts.getNano()})")
.asInstanceOf[T => Line]
case Primitive.Document =>
{ (node: Node) =>
node.accept(new NodeVisitor[Line] {
def arrayNode(x: ArrayNode): Line = {
val innerValues = x.getElements().asScala.map(_.accept(this))
line"smithy4s.Document.array(${innerValues.toList.intercalate(Line.comma)})"
}
line"smithy4s.Document.obj(${members.toList.intercalate(Line.comma)})"
}
def stringNode(x: StringNode): Line =
line"""smithy4s.Document.fromString(${renderStringLiteral(
x.getValue
)})"""
})
}
case Primitive.Nothing => v => (v: Nothing) // this case can't happen
def booleanNode(x: BooleanNode): Line =
line"smithy4s.Document.fromBoolean(${x.getValue})"
def nullNode(x: NullNode): Line =
line"smithy4s.Document.nullDoc"
def numberNode(x: NumberNode): Line =
line"smithy4s.Document.fromDouble(${x.getValue.doubleValue()}d)"
def objectNode(x: ObjectNode): Line = {
val members = x.getMembers.asScala.map { member =>
val key = renderStringLiteral(member._1.getValue)
val value = member._2.accept(this)
line"$key -> $value"
}
line"smithy4s.Document.obj(${members.toList.intercalate(Line.comma)})"
}
def stringNode(x: StringNode): Line =
line"""smithy4s.Document.fromString(${renderStringLiteral(
x.getValue
)})"""
})
}.asInstanceOf[T => Line]
case Primitive.Nothing => _ => sys.error("cannot happen") // this case can't happen
}

private def renderStringLiteral(raw: String): Line = {
import scala.reflect.runtime.universe._
val str = Literal(Constant(raw))
.toString()
// Escape the string as a Scala literal (surround with quotes, escape special characters)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

homestly I think we should just use scalameta for small things like these (rendering constants)

Copy link
Contributor

@Baccata Baccata May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fine, if you want to pull it I don't mind. But don't get your hopes up about rewriting the codegen in scalameta

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I'm not

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it in the past and while having the snippets to compile at the same time you compile your code is mind-blowing there was a lot of hassle sometimes to get them right.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, the ancestor to smithy4s was actually written with scalameta. Lots of friction.

Copy link
Member

@kubukoz kubukoz Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hints:

  • libraryDependencies += ("org.scalameta" %% "scalameta" % "4.11.0").cross(CrossVersion.for3Use2_13)
  • scala.meta.Lit.String(s).printSyntaxFor(scala.meta.dialects.Scala3)

that's how I use it in https://github.com/kubukoz/smithy-playground/blob/c34be16f534795041295f4379a6ca8040297c4b1/modules/parser-gen/src/main/scala/playground/parsergen/ParserGen.scala#L355C25-L355C92 anyway

val str = raw
.flatMap {
case '"' => "\\\""
case '\\' => "\\\\"
case '\n' => "\\n"
case '\r' => "\\r"
case '\t' => "\\t"
case c if c.isControl => f"\\u${c.toInt}%04x"
case c => c.toString
}
.mkString("\"", "", "\"")
// Replace sequences like "\\uD83D" (how Smithy specs refer to unicode characters)
// with unicode character escapes like "\uD83D" that can be parsed in the regex implementations on all platforms.
// See https://github.com/disneystreaming/smithy4s/pull/499
.replace("\\\\u", "\\u")

// If the string contains "$", the use of -Xlint:missing-interpolator when
// compiling the generated code will result in warnings. To prevent that we
// render any such strings as interpolated strings (even though that would
// otherwise be unecessary) so that we can render "$" as "$$", which get
// converted back to "$" during interpolation.
val escaped = if (str.contains('$')) s"s${str.replace("$", "$$")}" else str
val escaped =
if (str.contains('$')) s"s${str.replace("$", "$$")}"
else str

line"$escaped"
}

}
18 changes: 12 additions & 6 deletions modules/codegen/src/smithy4s/codegen/internals/SmithyToIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -825,11 +825,10 @@ private[codegen] class SmithyToIR(
def unionShape(x: UnionShape): Option[Type] =
Type.Ref(x.namespace, x.name).some

def memberShape(x: MemberShape): Option[Type] =
model.getShape(x.getTarget()).asScala.flatMap { shape =>
val builder =
(Shape.shapeToBuilder(shape: Shape): AbstractShapeBuilder[_, _])

def memberShape(x: MemberShape): Option[Type] = {
def processBuilder[S <: Shape, B <: AbstractShapeBuilder[B, S]](
builder: AbstractShapeBuilder[B, S]
) = {
builder
.addTraits(x.getAllTraits().asScala.map(_._2).asJavaCollection)

Expand All @@ -838,6 +837,13 @@ private[codegen] class SmithyToIR(
.accept(this)
}

model.getShape(x.getTarget()).asScala.flatMap { shape =>
val builder =
(Shape.shapeToBuilder(shape: Shape): AbstractShapeBuilder[_, _])
processBuilder(builder)
}
}

def timestampShape(x: TimestampShape): Option[Type] =
primitive(x, "smithy.api#Timestamp", Primitive.Timestamp)

Expand Down Expand Up @@ -1398,7 +1404,7 @@ private[codegen] class SmithyToIR(
case (node, IdRefCase()) =>
val ref = Type.Ref("smithy4s", "ShapeId")
val namespace :: name :: _ =
node.asStringNode.get.getValue.split("#").toList
node.asStringNode.get.getValue.split("#").toList: @unchecked
def toField(value: String) = TypedNode.FieldTN.RequiredTN(
NodeAndType(
Node.from(value),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ class ValidatedNewtypesTransformer extends ProjectionTransformer {
)
}

private def processShape(shape: Shape, lookup: String => Boolean) =
private def processShape[S <: Shape, B <: AbstractShapeBuilder[B, S]](
shape: S,
lookup: String => Boolean
) =
if (lookup(shape.getId().getNamespace()))
shape match {
case ValidatedNewtypesTransformer.SupportedShape(s) =>
addTrait(Shape.shapeToBuilder(s): AbstractShapeBuilder[_, _])
addTrait[S, B](Shape.shapeToBuilder(s): AbstractShapeBuilder[B, S])
case _ => shape
}
else
Expand All @@ -87,11 +90,11 @@ object ValidatedNewtypesTransformer {
private val METADATA_KEY = "smithy4sRenderValidatedNewtypes"

object SupportedShape {
def unapply(shape: Shape): Option[Shape] = shape match {
def unapply[S <: Shape](shape: S): Option[S] = shape match {
case _ if shape.hasTrait(classOf[UnwrapTrait]) => None
case _ if shape.hasTrait(classOf[ValidateNewtypeTrait]) => None
case s: StringShape if hasStringConstraints(s) => Some(s)
case n: NumberShape if hasNumberConstraints(n) => Some(n)
case s: StringShape if hasStringConstraints(s) => Some(shape)
case n: NumberShape if hasNumberConstraints(n) => Some(shape)
case _ => None
}

Expand Down
6 changes: 6 additions & 0 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ import org.portablescala.sbtplatformdeps.PlatformDepsPlugin.autoImport._

object Dependencies {

lazy val coursier_2 = "io.get-coursier" %% "coursier" % "2.1.24"
lazy val coursier_3 =
("io.get-coursier" % "coursier" % "2.1.24" cross CrossVersion.for3Use2_13)
.exclude("org.scala-lang.modules", "scala-collection-compat_2.13")
.exclude("com.github.plokhotnyuk.jsoniter-scala", "jsoniter-scala-core_3")

val collectionsCompat =
Def.setting(
"org.scala-lang.modules" %%% "scala-collection-compat" % "2.11.0"
Expand Down
Loading