|
1 | 1 | import ts from "typescript"; |
| 2 | +import { traverseTypeHierarchy } from "./type-traversal.ts"; |
2 | 3 |
|
3 | 4 | export type CellBrand = |
4 | 5 | | "opaque" |
@@ -56,50 +57,13 @@ function findCellBrandSymbol( |
56 | 57 | checker: ts.TypeChecker, |
57 | 58 | seen: Set<ts.Type>, |
58 | 59 | ): ts.Symbol | undefined { |
59 | | - if (seen.has(type)) return undefined; |
60 | | - seen.add(type); |
61 | | - |
62 | | - const direct = getBrandSymbolFromType(type, checker); |
63 | | - if (direct) return direct; |
64 | | - |
65 | | - const apparent = checker.getApparentType(type); |
66 | | - if (apparent !== type) { |
67 | | - const fromApparent = findCellBrandSymbol(apparent, checker, seen); |
68 | | - if (fromApparent) return fromApparent; |
69 | | - } |
70 | | - |
71 | | - if (type.flags & (ts.TypeFlags.Union | ts.TypeFlags.Intersection)) { |
72 | | - const compound = type as ts.UnionOrIntersectionType; |
73 | | - for (const child of compound.types) { |
74 | | - const childSymbol = findCellBrandSymbol(child, checker, seen); |
75 | | - if (childSymbol) return childSymbol; |
76 | | - } |
77 | | - } |
78 | | - |
79 | | - if (!(type.flags & ts.TypeFlags.Object)) { |
80 | | - return undefined; |
81 | | - } |
82 | | - |
83 | | - const objectType = type as ts.ObjectType; |
84 | | - |
85 | | - if (objectType.objectFlags & ts.ObjectFlags.Reference) { |
86 | | - const typeRef = objectType as ts.TypeReference; |
87 | | - if (typeRef.target) { |
88 | | - const fromTarget = findCellBrandSymbol(typeRef.target, checker, seen); |
89 | | - if (fromTarget) return fromTarget; |
90 | | - } |
91 | | - } |
92 | | - |
93 | | - if (objectType.objectFlags & ts.ObjectFlags.ClassOrInterface) { |
94 | | - const baseTypes = checker.getBaseTypes(objectType as ts.InterfaceType) ?? |
95 | | - []; |
96 | | - for (const base of baseTypes) { |
97 | | - const fromBase = findCellBrandSymbol(base, checker, seen); |
98 | | - if (fromBase) return fromBase; |
99 | | - } |
100 | | - } |
101 | | - |
102 | | - return undefined; |
| 60 | + return traverseTypeHierarchy(type, { |
| 61 | + checker, |
| 62 | + checkType: (t) => getBrandSymbolFromType(t, checker), |
| 63 | + visitApparentType: true, |
| 64 | + visitTypeReferenceTarget: true, |
| 65 | + visitBaseTypes: true, |
| 66 | + }, seen); |
103 | 67 | } |
104 | 68 |
|
105 | 69 | export function getCellBrand( |
@@ -176,38 +140,23 @@ function extractWrapperTypeReference( |
176 | 140 | checker: ts.TypeChecker, |
177 | 141 | seen: Set<ts.Type>, |
178 | 142 | ): ts.TypeReference | undefined { |
179 | | - if (seen.has(type)) return undefined; |
180 | | - seen.add(type); |
181 | | - |
182 | | - if (type.flags & ts.TypeFlags.Object) { |
183 | | - const objectType = type as ts.ObjectType; |
184 | | - if (objectType.objectFlags & ts.ObjectFlags.Reference) { |
185 | | - const typeRef = objectType as ts.TypeReference; |
186 | | - const typeArgs = typeRef.typeArguments ?? |
187 | | - checker.getTypeArguments(typeRef); |
188 | | - if (typeArgs && typeArgs.length > 0) { |
189 | | - return typeRef; |
| 143 | + return traverseTypeHierarchy(type, { |
| 144 | + checker, |
| 145 | + checkType: (t) => { |
| 146 | + if (t.flags & ts.TypeFlags.Object) { |
| 147 | + const objectType = t as ts.ObjectType; |
| 148 | + if (objectType.objectFlags & ts.ObjectFlags.Reference) { |
| 149 | + const typeRef = objectType as ts.TypeReference; |
| 150 | + const typeArgs = typeRef.typeArguments ?? |
| 151 | + checker.getTypeArguments(typeRef); |
| 152 | + if (typeArgs && typeArgs.length > 0) { |
| 153 | + return typeRef; |
| 154 | + } |
| 155 | + } |
190 | 156 | } |
191 | | - } |
192 | | - } |
193 | | - |
194 | | - if (type.flags & ts.TypeFlags.Intersection) { |
195 | | - const intersectionType = type as ts.IntersectionType; |
196 | | - for (const constituent of intersectionType.types) { |
197 | | - const ref = extractWrapperTypeReference(constituent, checker, seen); |
198 | | - if (ref) return ref; |
199 | | - } |
200 | | - } |
201 | | - |
202 | | - if (type.flags & ts.TypeFlags.Union) { |
203 | | - const unionType = type as ts.UnionType; |
204 | | - for (const member of unionType.types) { |
205 | | - const ref = extractWrapperTypeReference(member, checker, seen); |
206 | | - if (ref) return ref; |
207 | | - } |
208 | | - } |
209 | | - |
210 | | - return undefined; |
| 157 | + return undefined; |
| 158 | + }, |
| 159 | + }, seen); |
211 | 160 | } |
212 | 161 |
|
213 | 162 | export function getCellWrapperInfo( |
|
0 commit comments