1
16
17 package dra
18
import (
	"fmt"
	"sort"
	"sync"

	resourcev1alpha2 "k8s.io/api/resource/v1alpha2"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/apimachinery/pkg/util/sets"
	"k8s.io/kubernetes/pkg/kubelet/cm/dra/state"
	"k8s.io/kubernetes/pkg/kubelet/cm/util/cdi"
	kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
)
30
31
32
// ClaimInfo holds the cached state for a single resource claim: the
// checkpointable ClaimInfoState plus in-memory-only data derived from it.
type ClaimInfo struct {
	// The methods in this file take this lock before reading or writing
	// PodUIDs, CDIDevices and annotations.
	sync.RWMutex
	state.ClaimInfoState
	// annotations holds the container annotations generated from CDI device
	// names, keyed by the name of the DRA plugin that reported those devices
	// (see addCDIDevices).
	annotations map[string][]kubecontainer.Annotation
	// NOTE(review): prepared is not read or written in this file; presumably
	// it marks that the claim has been prepared by its plugin(s) — confirm
	// against callers.
	prepared bool
}
41
42 func (info *ClaimInfo) addPodReference(podUID types.UID) {
43 info.Lock()
44 defer info.Unlock()
45
46 info.PodUIDs.Insert(string(podUID))
47 }
48
49 func (info *ClaimInfo) deletePodReference(podUID types.UID) {
50 info.Lock()
51 defer info.Unlock()
52
53 info.PodUIDs.Delete(string(podUID))
54 }
55
56 func (info *ClaimInfo) addCDIDevices(pluginName string, cdiDevices []string) error {
57 info.Lock()
58 defer info.Unlock()
59
60
61
62
63 annotations, err := cdi.GenerateAnnotations(info.ClaimUID, info.DriverName, cdiDevices)
64 if err != nil {
65 return fmt.Errorf("failed to generate container annotations, err: %+v", err)
66 }
67
68 if info.CDIDevices == nil {
69 info.CDIDevices = make(map[string][]string)
70 }
71
72 info.CDIDevices[pluginName] = cdiDevices
73 info.annotations[pluginName] = annotations
74
75 return nil
76 }
77
78
79 func (info *ClaimInfo) annotationsAsList() []kubecontainer.Annotation {
80 info.RLock()
81 defer info.RUnlock()
82
83 var lst []kubecontainer.Annotation
84 for _, v := range info.annotations {
85 lst = append(lst, v...)
86 }
87 return lst
88 }
89
90
// claimInfoCache is an in-memory cache of ClaimInfo entries, indexed by
// claim name and namespace, that can be persisted to a checkpoint.
type claimInfoCache struct {
	// The methods in this file take this lock before accessing claimInfo.
	sync.RWMutex
	// state is the checkpoint store that syncToCheckpoint writes to.
	state state.CheckpointState
	// claimInfo maps a key derived from a claim's name and namespace to the
	// cached ClaimInfo for that claim.
	claimInfo map[string]*ClaimInfo
}
96
97 func newClaimInfo(driverName, className string, claimUID types.UID, claimName, namespace string, podUIDs sets.Set[string], resourceHandles []resourcev1alpha2.ResourceHandle) *ClaimInfo {
98 claimInfoState := state.ClaimInfoState{
99 DriverName: driverName,
100 ClassName: className,
101 ClaimUID: claimUID,
102 ClaimName: claimName,
103 Namespace: namespace,
104 PodUIDs: podUIDs,
105 ResourceHandles: resourceHandles,
106 }
107 claimInfo := ClaimInfo{
108 ClaimInfoState: claimInfoState,
109 annotations: make(map[string][]kubecontainer.Annotation),
110 }
111 return &claimInfo
112 }
113
114
115 func newClaimInfoFromResourceClaim(resourceClaim *resourcev1alpha2.ResourceClaim) *ClaimInfo {
116
117
118
119
120 resourceHandles := resourceClaim.Status.Allocation.ResourceHandles
121 if len(resourceHandles) == 0 {
122 resourceHandles = make([]resourcev1alpha2.ResourceHandle, 1)
123 }
124
125 return newClaimInfo(
126 resourceClaim.Status.DriverName,
127 resourceClaim.Spec.ResourceClassName,
128 resourceClaim.UID,
129 resourceClaim.Name,
130 resourceClaim.Namespace,
131 make(sets.Set[string]),
132 resourceHandles,
133 )
134 }
135
136
137 func newClaimInfoCache(stateDir, checkpointName string) (*claimInfoCache, error) {
138 stateImpl, err := state.NewCheckpointState(stateDir, checkpointName)
139 if err != nil {
140 return nil, fmt.Errorf("could not initialize checkpoint manager, please drain node and remove dra state file, err: %+v", err)
141 }
142
143 curState, err := stateImpl.GetOrCreate()
144 if err != nil {
145 return nil, fmt.Errorf("error calling GetOrCreate() on checkpoint state: %v", err)
146 }
147
148 cache := &claimInfoCache{
149 state: stateImpl,
150 claimInfo: make(map[string]*ClaimInfo),
151 }
152
153 for _, entry := range curState {
154 info := newClaimInfo(
155 entry.DriverName,
156 entry.ClassName,
157 entry.ClaimUID,
158 entry.ClaimName,
159 entry.Namespace,
160 entry.PodUIDs,
161 entry.ResourceHandles,
162 )
163 for pluginName, cdiDevices := range entry.CDIDevices {
164 err := info.addCDIDevices(pluginName, cdiDevices)
165 if err != nil {
166 return nil, fmt.Errorf("failed to add CDIDevices to claimInfo %+v: %+v", info, err)
167 }
168 }
169 cache.add(info)
170 }
171
172 return cache, nil
173 }
174
175 func (cache *claimInfoCache) add(res *ClaimInfo) {
176 cache.Lock()
177 defer cache.Unlock()
178
179 cache.claimInfo[res.ClaimName+res.Namespace] = res
180 }
181
182 func (cache *claimInfoCache) get(claimName, namespace string) *ClaimInfo {
183 cache.RLock()
184 defer cache.RUnlock()
185
186 return cache.claimInfo[claimName+namespace]
187 }
188
189 func (cache *claimInfoCache) delete(claimName, namespace string) {
190 cache.Lock()
191 defer cache.Unlock()
192
193 delete(cache.claimInfo, claimName+namespace)
194 }
195
196
197
198
199
200 func (cache *claimInfoCache) hasPodReference(UID types.UID) bool {
201 cache.RLock()
202 defer cache.RUnlock()
203
204 for _, claimInfo := range cache.claimInfo {
205 if claimInfo.PodUIDs.Has(string(UID)) {
206 return true
207 }
208 }
209
210 return false
211 }
212
213 func (cache *claimInfoCache) syncToCheckpoint() error {
214 cache.RLock()
215 defer cache.RUnlock()
216
217 claimInfoStateList := make(state.ClaimInfoStateList, 0, len(cache.claimInfo))
218 for _, infoClaim := range cache.claimInfo {
219 claimInfoStateList = append(claimInfoStateList, infoClaim.ClaimInfoState)
220 }
221
222 return cache.state.Store(claimInfoStateList)
223 }
224
View as plain text