scalalandio / chimney
1
package io.scalaland.chimney.internal.utils
2

3
import scala.reflect.macros.blackbox
4

5
/** Shared helpers for blackbox macros: extension methods over the macro universe's
  * `Name`s, `Type`s, `Symbol`s and `Tree`s, plus builders of trees that manipulate
  * transformer definitions.
  *
  * All members operate on the universe of the abstract macro context `c`.
  */
trait MacroUtils extends CompanionUtils {

  val c: blackbox.Context

  import c.universe._

  /** Conversions of a `Name` into constants, literal trees and singleton types. */
  implicit class NameOps(n: Name) {

    /** The decoded form of the name wrapped in a compile-time `Constant`. */
    def toNameConstant: Constant = Constant(n.decodedName.toString)

    /** A literal tree holding `toNameConstant`. */
    def toNameLiteral: Literal = Literal(toNameConstant)

    /** The constant (singleton) type of the decoded name, e.g. `"field".type`. */
    def toSingletonTpe: ConstantType = c.internal.constantType(toNameConstant)

    /** The plain `toString` of the name (no decoding is applied, unlike `toNameConstant`). */
    def toCanonicalName: String = n.toString
  }

  /** Weak type tag of a unary type constructor, witnessed through its application to `Unit`. */
  type TypeConstructorTag[F[_]] = WeakTypeTag[F[Unit]]

  object TypeConstructorTag {

    /** Recovers the bare type constructor `F` from the tag of `F[Unit]`. */
    def apply[F[_]: TypeConstructorTag]: Type = {
      weakTypeOf[F[Unit]].typeConstructor
    }
  }

  /** Extension methods for inspecting and manipulating types during macro expansion. */
  implicit class TypeOps(t: Type) {

    /** Applies the single type parameter of the eta-expanded `t` to `arg`.
      *
      * Aborts macro expansion when `t` does not take exactly one type parameter.
      */
    def applyTypeArg(arg: Type): Type = {
      val ee = t.etaExpand
      if (ee.typeParams.size != 1) {
        // $COVERAGE-OFF$
        c.abort(c.enclosingPosition, s"Type $ee must have single type parameter!")
        // $COVERAGE-ON$
      }
      ee.finalResultType.substituteTypes(ee.typeParams, List(arg))
    }

    /** Applies all type parameters of the eta-expanded `t` to `args`, in order.
      *
      * Aborts macro expansion when the number of `args` differs from the arity of `t`.
      */
    def applyTypeArgs(args: Type*): Type = {
      val ee = t.etaExpand
      if (ee.typeParams.size != args.size) {
        // $COVERAGE-OFF$
        val een = ee.typeParams.size
        val argsn = args.size
        c.abort(c.enclosingPosition, s"Type $ee has different arity ($een) than applied to applyTypeArgs ($argsn)!")
        // $COVERAGE-ON$
      }
      ee.finalResultType.substituteTypes(ee.typeParams, args.toList)
    }

    /** True for subtypes of `AnyVal` other than the primitive types listed in `primitives`. */
    def isValueClass: Boolean =
      t <:< typeOf[AnyVal] && !primitives.exists(_ =:= t)

    /** True when the type's symbol is a case class. */
    def isCaseClass: Boolean =
      t.typeSymbol.isCaseClass

    /** True when the type's symbol is a class (or trait) marked `sealed`. */
    def isSealedClass: Boolean =
      t.typeSymbol.classSymbolOpt.exists(_.isSealed)

    /** Case accessors declared directly on `t`; for value classes, parameter accessors too. */
    def caseClassParams: Seq[MethodSymbol] = {
      // Loop-invariant: whether `t` is a value class does not depend on the declaration.
      val valueClass = isValueClass
      t.decls.collect {
        case m: MethodSymbol if m.isCaseAccessor || (valueClass && m.isParamAccessor) => m
      }.toSeq
    }

    /** Public getters and public parameterless methods declared directly on `t` (not inherited). */
    def getterMethods: Seq[MethodSymbol] = {
      t.decls.collect {
        case m: MethodSymbol if m.isPublic && (m.isGetter || m.isParameterless) =>
          m
      }.toSeq
    }

    /** All members of `t`, inherited ones included, that satisfy `MethodSymbolOps.isBeanSetter`. */
    def beanSetterMethods: Seq[MethodSymbol] = {
      t.members.collect { case m: MethodSymbol if m.isBeanSetter => m }.toSeq
    }

    /** The first parameter accessor declared on `t`.
      *
      * Intended for value classes, where this is the wrapped field.
      */
    def valueClassMember: Option[MethodSymbol] = {
      t.decls.collectFirst {
        case m: MethodSymbol if m.isParamAccessor => m
      }
    }

    /** The `String` value carried by a constant (singleton) type such as `"field".type`.
      *
      * Uses the public `ConstantType` extractor rather than a cast to a compiler-internal
      * class, so an unexpected type produces a macro error instead of a `ClassCastException`.
      */
    def singletonString: String = t match {
      case c.universe.ConstantType(Constant(value: String)) =>
        value
      // $COVERAGE-OFF$
      case _ =>
        c.abort(c.enclosingPosition, s"Type $t is not a singleton String type!")
      // $COVERAGE-ON$
    }

    /** The element type of a collection type.
      *
      * A single type argument is returned as-is. Two type arguments (e.g. a map's key and
      * value) are combined into a tuple type `(K, V)`. Any other arity aborts macro expansion.
      */
    def collectionInnerTpe: Type = {
      t.typeArgs match {
        case List(unaryInnerT) => unaryInnerT
        case List(innerT1, innerT2) =>
          c.typecheck(tq"($innerT1, $innerT2)", c.TYPEmode).tpe
        // $COVERAGE-OFF$
        case Nil =>
          c.abort(c.enclosingPosition, "Collection type must have type parameters!")
        case _ =>
          c.abort(c.enclosingPosition, "Collection types with more than 2 type arguments are not supported!")
        // $COVERAGE-ON$
      }
    }

    /** The symbol identifying `t` as a coproduct case.
      *
      * When `t` is a constant type whose value is a `TermSymbol` (the binding is named
      * `enumeration`, so presumably an enum constant), that term symbol is returned.
      * Otherwise `t.typeSymbol` is returned.
      */
    def coproductSymbol: Symbol = t match {
      case c.universe.ConstantType(Constant(enumeration: TermSymbol)) => enumeration
      case _                                                          => t.typeSymbol
    }
  }

  /** Extension methods for inspecting symbols during macro expansion. */
  implicit class SymbolOps(s: Symbol) {

    /** The symbol as a `ClassSymbol`, or `None` when it is not a class. */
    def classSymbolOpt: Option[ClassSymbol] =
      if (s.isClass) Some(s.asClass) else None

    /** True when the symbol is a class marked as a case class. */
    def isCaseClass: Boolean =
      classSymbolOpt.exists(_.isCaseClass)

    /** Default values of the case class parameters, keyed by parameter name.
      *
      * Only the first parameter list of the primary constructor is inspected. A parameter
      * with a default maps to a tree selecting the companion's corresponding `apply`
      * default-argument method (1-based index). The map is empty when the symbol is not a class.
      */
    lazy val caseClassDefaults: Map[String, Tree] = {
      // Result discarded on purpose: forces the signature to be completed before it is
      // inspected (presumably the same SI-7755 workaround as in `typeInSealedParent`).
      s.typeSignature
      classSymbolOpt
        .flatMap { classSymbol =>
          val classType = classSymbol.toType
          val companionSym = companionSymbol(classType)
          // NOTE(review): `asModule` throws when `companionSym` is not a module; callers are
          // assumed to pass case classes, which always have a companion.
          val primaryFactoryMethod = companionSym.asModule.info.decl(TermName("apply")).alternatives.lastOption
          // Result discarded on purpose: forces completion of the factory's signature.
          primaryFactoryMethod.foreach(_.asMethod.typeSignature)
          val primaryConstructor = classSymbol.primaryConstructor
          val headParamListOpt = primaryConstructor.asMethod.typeSignature.paramLists.headOption.map(_.map(_.asTerm))

          headParamListOpt.map { headParamList =>
            headParamList.zipWithIndex.flatMap {
              case (param, idx) =>
                if (param.isParamWithDefault) {
                  // scalac names default-argument accessors with 1-based parameter indices.
                  val method = TermName("apply$default$" + (idx + 1))
                  Some(param.name.toString -> q"$companionSym.$method")
                } else {
                  None
                }
            }.toMap
          }
        }
        .getOrElse {
          // $COVERAGE-OFF$
          Map.empty
          // $COVERAGE-ON$
        }
    }

    /** The type of subtype symbol `s` as seen inside its sealed parent `parentTpe`.
      *
      * Type parameters of `s` that are passed to the parent are replaced by the
      * corresponding type arguments of `parentTpe`. Java enum constants keep their own
      * type signature.
      */
    def typeInSealedParent(parentTpe: Type): Type = {
      s.typeSignature // Workaround for <https://issues.scala-lang.org/browse/SI-7755>

      if (s.isJavaEnum) {
        s.typeSignature
      } else {
        val sEta = s.asType.toType.etaExpand
        sEta.finalResultType.substituteTypes(
          sEta.baseType(parentTpe.typeSymbol).typeArgs.map(_.typeSymbol),
          parentTpe.typeArgs
        )
      }
    }
  }

  /** Extension methods for inspecting method symbols (getters and bean setters). */
  implicit class MethodSymbolOps(ms: MethodSymbol) {

    /** For bean setters, the property name (`setFooBar` becomes `fooBar`); otherwise the
      * decoded method name.
      */
    def canonicalName: String = {
      val name = ms.name.decodedName.toString
      if (isBeanSetter) {
        // isBeanSetter guarantees the name is longer than "set", so `property` is non-empty.
        val property = name.drop(3)
        s"${property.head.toLower}${property.tail}"
      } else {
        name
      }
    }

    /** True for a public `setXxx` method (longer than the bare `set`) with exactly one
      * parameter list of exactly one parameter, whose return type is equivalent to `Unit`.
      */
    def isBeanSetter: Boolean = {
      val name = ms.name.decodedName.toString
      ms.isPublic &&
      name.startsWith("set") &&
      name.lengthCompare(3) > 0 &&
      ms.paramLists.lengthCompare(1) == 0 &&
      ms.paramLists.head.lengthCompare(1) == 0 &&
      // `=:=` rather than `==`: `==` compares type representations and misses
      // equivalent types such as aliases of Unit.
      ms.returnType =:= typeOf[Unit]
    }

    /** The final result type of the method as seen from `site`. */
    def resultTypeIn(site: Type): Type = {
      ms.typeSignatureIn(site).finalResultType
    }

    /** The type, as seen from `site`, of the first parameter of the first parameter list.
      *
      * Assumes the method is a bean setter (see `isBeanSetter`); it fails on methods without parameters.
      */
    def beanSetterParamTypeIn(site: Type): Type = {
      ms.paramLists.head.head.typeSignatureIn(site)
    }

    /** True when the method has no parameter lists or only a single empty one. */
    def isParameterless: Boolean = {
      ms.paramLists.isEmpty || ms.paramLists == List(List())
    }
  }

  /** Extension methods for class symbols. */
  implicit class ClassSymbolOps(cs: ClassSymbol) {

    /** Known direct subclasses, where every sealed trait among them is recursively replaced
      * by its own subclasses.
      *
      * NOTE(review): only sealed traits are flattened; sealed abstract classes are returned as-is.
      */
    def subclasses: List[Symbol] =
      cs.knownDirectSubclasses.toList.flatMap { subclass =>
        val asClass = subclass.asClass
        if (asClass.isTrait && asClass.isSealed) {
          asClass.subclasses
        } else {
          List(subclass)
        }
      }
  }

  // $COVERAGE-OFF$
  /** Extension methods for building and deconstructing trees. */
  implicit class TreeOps(t: Tree) {

    /** Prints the tree and its raw representation to stdout and returns it unchanged. */
    def debug: Tree = {
      println("TREE: " + t)
      println("RAW:  " + showRaw(t))
      t
    }

    /** Splits a block (possibly wrapped in `Typed`) into its statements and result
      * expression; any other tree yields no statements and itself as the result.
      */
    def extractBlock: (List[Tree], Tree) = t match {
      case Typed(tt, _) =>
        tt.extractBlock
      case Block(stats, expr) =>
        (stats, expr)
      case other =>
        (Nil, other)
    }

    /** The statements of a block (possibly wrapped in `Typed`); `Nil` for any other tree. */
    def extractStats: List[Tree] = t match {
      case Typed(tt, _) =>
        tt.extractStats
      case Block(stats, _) =>
        stats
      case _ =>
        Nil
    }

    /** Appends `tree` after the statements of `t` (see `extractBlock`), keeping the result expression. */
    def insertToBlock(tree: Tree): Tree = {
      val (stats, expr) = t.extractBlock
      Block(stats :+ tree, expr)
    }

    /** Like `extractSelectorFieldNameOpt`, but aborts macro expansion on an invalid selector. */
    def extractSelectorFieldName: Name = {
      extractSelectorFieldNameOpt.getOrElse {
        c.abort(c.enclosingPosition, "Invalid selector!")
      }
    }

    /** Matches a selector lambda `(x) => x.field` and returns `field`.
      *
      * The lambda parameter and the selected identifier must have the same name.
      */
    def extractSelectorFieldNameOpt: Option[Name] = {
      t match {
        case q"(${vd: ValDef}) => ${idt: Ident}.${fieldName: Name}" if vd.name == idt.name =>
          Some(fieldName)
        case _ =>
          None
      }
    }

    /** Converts the collection tree `t` into `TargetTpe`.
      *
      * `Map` targets on Scala versions below 2.13 use `.toMap`. Otherwise `.to` is called
      * with an implicitly summoned `scala.collection.compat.Factory[InnerTpe, TargetTpe]`.
      * The version check is a lexicographic string comparison against the running Scala
      * version.
      */
    def convertCollection(TargetTpe: Type, InnerTpe: Type): Tree = {
      if (TargetTpe <:< typeOf[scala.collection.Map[_, _]] && scala.util.Properties.versionNumberString < "2.13") {
        q"$t.toMap"
      } else {
        q"$t.to(_root_.scala.Predef.implicitly[_root_.scala.collection.compat.Factory[$InnerTpe, $TargetTpe]])"
      }
    }

    /** A tree calling `t.transform(input)`. */
    def callTransform(input: Tree): Tree = {
      q"$t.transform($input)"
    }
  }
  // $COVERAGE-ON$

  /** Builders of trees that read from or refine a transformer-definition tree `td`. */
  implicit class TransformerDefinitionTreeOps(td: Tree) {

    /** A tree reading the override stored under `name` in `td`, cast to `targetTpe`. */
    def accessConst(name: String, targetTpe: Type): Tree = {
      q"""
        $td
          .overrides($name)
          .asInstanceOf[$targetTpe]
      """
    }

    /** A tree reading the override stored under `name` in `td`, casting it to the function
      * `fromTpe => targetTpe` and applying it to `srcPrefixTree`.
      */
    def accessComputed(name: String, srcPrefixTree: Tree, fromTpe: Type, targetTpe: Type): Tree = {
      q"""
        $td
          .overrides($name)
          .asInstanceOf[$fromTpe => $targetTpe]
          .apply($srcPrefixTree)
      """
    }

    /** A tree calling `td.__addOverride` with the decoded field name literal and `value`. */
    def addOverride(fieldName: Name, value: Tree): Tree = {
      q"$td.__addOverride(${fieldName.toNameLiteral}, $value)"
    }

    /** A tree calling `td.__addInstance(fullInstName, fullTargetName, f)`. */
    def addInstance(fullInstName: String, fullTargetName: String, f: Tree): Tree = {
      q"$td.__addInstance($fullInstName, $fullTargetName, $f)"
    }

    /** A tree calling `td.__refineConfig[cfgTpe]`. */
    def refineConfig(cfgTpe: Type): Tree = {
      q"$td.__refineConfig[$cfgTpe]"
    }

    /** A tree calling `td.__refineTransformerDefinition(definitionRefinementFn)`. */
    def refineTransformerDefinition(definitionRefinementFn: Tree): Tree = {
      q"$td.__refineTransformerDefinition($definitionRefinementFn)"
    }

    /** Variant of `refineTransformerDefinition` that first binds `valTree`'s tree to a
      * fresh local `val`. It then passes a map from `valTree`'s name to an identifier of
      * that `val` to `definitionRefinementFn`, which builds the refinement function tree.
      */
    def refineTransformerDefinition_Hack(
        definitionRefinementFn: Map[String, Tree] => Tree,
        valTree: (String, Tree)
    ): Tree = {
      // normally, we would like to use refineTransformerDefinition, which works well on Scala 2.11
      // in few cases on Scala 2.12+ it ends up as 'Error while emitting XXX.scala' compiler error
      // with this hack, we can work around scalac bugs

      val (name, tree) = valTree
      val fnTermName = TermName(c.freshName(name))
      val fnMapTree = Map(name -> Ident(fnTermName))
      q"""
        {
          val ${fnTermName} = $tree
          $td.__refineTransformerDefinition(${definitionRefinementFn(fnMapTree)})
        }
      """
    }
  }

  /** Types of Scala's primitive value classes, excluded by `TypeOps.isValueClass`. */
  private val primitives = Set(
    typeOf[Double],
    typeOf[Float],
    typeOf[Short],
    typeOf[Byte],
    typeOf[Int],
    typeOf[Long],
    typeOf[Char],
    typeOf[Boolean],
    typeOf[Unit]
  )
}

Read our documentation on viewing source code .

Loading