...
1 package rewrite
2
3 import (
4 "bytes"
5 "fmt"
6 "go/ast"
7 "go/token"
8 "os"
9 "path/filepath"
10 "strconv"
11 "strings"
12
13 "golang.org/x/tools/go/packages"
14
15 "github.com/99designs/gqlgen/internal/code"
16 )
17
18 type Rewriter struct {
19 pkg *packages.Package
20 files map[string]string
21 copied map[ast.Decl]bool
22 }
23
24 func New(dir string) (*Rewriter, error) {
25 importPath := code.ImportPathForDir(dir)
26 if importPath == "" {
27 return nil, fmt.Errorf("import path not found for directory: %q", dir)
28 }
29 pkgs, err := packages.Load(&packages.Config{
30 Mode: packages.NeedSyntax | packages.NeedTypes,
31 }, importPath)
32 if err != nil {
33 return nil, err
34 }
35 if len(pkgs) == 0 {
36 return nil, fmt.Errorf("package not found for importPath: %s", importPath)
37 }
38
39 return &Rewriter{
40 pkg: pkgs[0],
41 files: map[string]string{},
42 copied: map[ast.Decl]bool{},
43 }, nil
44 }
45
46 func (r *Rewriter) getSource(start, end token.Pos) string {
47 startPos := r.pkg.Fset.Position(start)
48 endPos := r.pkg.Fset.Position(end)
49
50 if startPos.Filename != endPos.Filename {
51 panic("cant get source spanning multiple files")
52 }
53
54 file := r.getFile(startPos.Filename)
55 return file[startPos.Offset:endPos.Offset]
56 }
57
58 func (r *Rewriter) getFile(filename string) string {
59 if _, ok := r.files[filename]; !ok {
60 b, err := os.ReadFile(filename)
61 if err != nil {
62 panic(fmt.Errorf("unable to load file, already exists: %w", err))
63 }
64
65 r.files[filename] = string(b)
66
67 }
68
69 return r.files[filename]
70 }
71
72 func (r *Rewriter) GetPrevDecl(structname, methodname string) *ast.FuncDecl {
73 for _, f := range r.pkg.Syntax {
74 for _, d := range f.Decls {
75 d, isFunc := d.(*ast.FuncDecl)
76 if !isFunc {
77 continue
78 }
79 if d.Name.Name != methodname {
80 continue
81 }
82 if d.Recv == nil || len(d.Recv.List) == 0 {
83 continue
84 }
85 recv := d.Recv.List[0].Type
86 if star, isStar := recv.(*ast.StarExpr); isStar {
87 recv = star.X
88 }
89 ident, ok := recv.(*ast.Ident)
90 if !ok {
91 continue
92 }
93 if ident.Name != structname {
94 continue
95 }
96 r.copied[d] = true
97 return d
98 }
99 }
100 return nil
101 }
102
103 func (r *Rewriter) GetMethodComment(structname, methodname string) string {
104 d := r.GetPrevDecl(structname, methodname)
105 if d != nil {
106 return d.Doc.Text()
107 }
108 return ""
109 }
110
111 func (r *Rewriter) GetMethodBody(structname, methodname string) string {
112 d := r.GetPrevDecl(structname, methodname)
113 if d != nil {
114 return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
115 }
116 return ""
117 }
118
119 func (r *Rewriter) MarkStructCopied(name string) {
120 for _, f := range r.pkg.Syntax {
121 for _, d := range f.Decls {
122 d, isGen := d.(*ast.GenDecl)
123 if !isGen {
124 continue
125 }
126 if d.Tok != token.TYPE || len(d.Specs) == 0 {
127 continue
128 }
129
130 spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec)
131 if !isTypeSpec {
132 continue
133 }
134
135 if spec.Name.Name != name {
136 continue
137 }
138
139 r.copied[d] = true
140 }
141 }
142 }
143
144 func (r *Rewriter) ExistingImports(filename string) []Import {
145 filename, err := filepath.Abs(filename)
146 if err != nil {
147 panic(err)
148 }
149 for _, f := range r.pkg.Syntax {
150 pos := r.pkg.Fset.Position(f.Pos())
151
152 if filename != pos.Filename {
153 continue
154 }
155
156 var imps []Import
157 for _, i := range f.Imports {
158 name := ""
159 if i.Name != nil {
160 name = i.Name.Name
161 }
162 path, err := strconv.Unquote(i.Path.Value)
163 if err != nil {
164 panic(err)
165 }
166 imps = append(imps, Import{name, path})
167 }
168 return imps
169 }
170 return nil
171 }
172
173 func (r *Rewriter) RemainingSource(filename string) string {
174 filename, err := filepath.Abs(filename)
175 if err != nil {
176 panic(err)
177 }
178 for _, f := range r.pkg.Syntax {
179 pos := r.pkg.Fset.Position(f.Pos())
180
181 if filename != pos.Filename {
182 continue
183 }
184
185 var buf bytes.Buffer
186
187 for _, d := range f.Decls {
188 if r.copied[d] {
189 continue
190 }
191
192 if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT {
193 continue
194 }
195
196 buf.WriteString(r.getSource(d.Pos(), d.End()))
197 buf.WriteString("\n")
198 }
199
200 return strings.TrimSpace(buf.String())
201 }
202 return ""
203 }
204
// Import describes a single import spec of a Go source file.
type Import struct {
	// Alias is the explicitly written local name ("_", "." or an
	// identifier), or "" when the package's default name is used.
	Alias string
	// ImportPath is the unquoted import path, e.g. "path/filepath".
	ImportPath string
}
209
View as plain text