1// Copyright (C) MongoDB, Inc. 2022-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7//+build gssapi,windows
8
9#include "sspi_wrapper.h"
10
11static HINSTANCE sspi_secur32_dll = NULL;
12static PSecurityFunctionTable sspi_functions = NULL;
13static const LPSTR SSPI_PACKAGE_NAME = "kerberos";
14
15int sspi_init(
16)
17{
18 // Load the secur32.dll library using its exact path. Passing the exact DLL path rather than allowing LoadLibrary to
19 // search in different locations removes the possibility of DLL preloading attacks. We use GetSystemDirectoryA and
20 // LoadLibraryA rather than the GetSystemDirectory/LoadLibrary aliases to ensure the ANSI versions are used so we
21 // don't have to account for variations in char sizes if UNICODE is enabled.
22
23 // Passing a 0 size will return the required buffer length to hold the path, including the null terminator.
24 int requiredLen = GetSystemDirectoryA(NULL, 0);
25 if (!requiredLen) {
26 return GetLastError();
27 }
28
29 // Allocate a buffer to hold the system directory + "\secur32.dll" (length 12, not including null terminator).
30 int actualLen = requiredLen + 12;
31 char *directoryBuffer = (char *) calloc(1, actualLen);
32 int directoryLen = GetSystemDirectoryA(directoryBuffer, actualLen);
33 if (!directoryLen) {
34 free(directoryBuffer);
35 return GetLastError();
36 }
37
38 // Append the DLL name to the buffer.
39 char *dllName = "\\secur32.dll";
40 strcpy_s(&(directoryBuffer[directoryLen]), actualLen - directoryLen, dllName);
41
42 sspi_secur32_dll = LoadLibraryA(directoryBuffer);
43 free(directoryBuffer);
44 if (!sspi_secur32_dll) {
45 return GetLastError();
46 }
47
48 INIT_SECURITY_INTERFACE init_security_interface = (INIT_SECURITY_INTERFACE)GetProcAddress(sspi_secur32_dll, SECURITY_ENTRYPOINT);
49 if (!init_security_interface) {
50 return -1;
51 }
52
53 sspi_functions = (*init_security_interface)();
54 if (!sspi_functions) {
55 return -2;
56 }
57
58 return SSPI_OK;
59}
60
61int sspi_client_init(
62 sspi_client_state *client,
63 char* username,
64 char* password
65)
66{
67 TimeStamp timestamp;
68
69 if (username) {
70 if (password) {
71 SEC_WINNT_AUTH_IDENTITY auth_identity;
72
73 #ifdef _UNICODE
74 auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
75 #else
76 auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
77 #endif
78 auth_identity.User = (LPSTR) username;
79 auth_identity.UserLength = strlen(username);
80 auth_identity.Password = (LPSTR) password;
81 auth_identity.PasswordLength = strlen(password);
82 auth_identity.Domain = NULL;
83 auth_identity.DomainLength = 0;
84 client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, &client->cred, ×tamp);
85 } else {
86 client->status = sspi_functions->AcquireCredentialsHandle(username, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, ×tamp);
87 }
88 } else {
89 client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, ×tamp);
90 }
91
92 if (client->status != SEC_E_OK) {
93 return SSPI_ERROR;
94 }
95
96 return SSPI_OK;
97}
98
99int sspi_client_username(
100 sspi_client_state *client,
101 char** username
102)
103{
104 SecPkgCredentials_Names names;
105 client->status = sspi_functions->QueryCredentialsAttributes(&client->cred, SECPKG_CRED_ATTR_NAMES, &names);
106
107 if (client->status != SEC_E_OK) {
108 return SSPI_ERROR;
109 }
110
111 int len = strlen(names.sUserName) + 1;
112 *username = malloc(len);
113 memcpy(*username, names.sUserName, len);
114
115 sspi_functions->FreeContextBuffer(names.sUserName);
116
117 return SSPI_OK;
118}
119
120int sspi_client_negotiate(
121 sspi_client_state *client,
122 char* spn,
123 PVOID input,
124 ULONG input_length,
125 PVOID* output,
126 ULONG* output_length
127)
128{
129 SecBufferDesc inbuf;
130 SecBuffer in_bufs[1];
131 SecBufferDesc outbuf;
132 SecBuffer out_bufs[1];
133
134 if (client->has_ctx > 0) {
135 inbuf.ulVersion = SECBUFFER_VERSION;
136 inbuf.cBuffers = 1;
137 inbuf.pBuffers = in_bufs;
138 in_bufs[0].pvBuffer = input;
139 in_bufs[0].cbBuffer = input_length;
140 in_bufs[0].BufferType = SECBUFFER_TOKEN;
141 }
142
143 outbuf.ulVersion = SECBUFFER_VERSION;
144 outbuf.cBuffers = 1;
145 outbuf.pBuffers = out_bufs;
146 out_bufs[0].pvBuffer = NULL;
147 out_bufs[0].cbBuffer = 0;
148 out_bufs[0].BufferType = SECBUFFER_TOKEN;
149
150 ULONG context_attr = 0;
151
152 client->status = sspi_functions->InitializeSecurityContext(
153 &client->cred,
154 client->has_ctx > 0 ? &client->ctx : NULL,
155 (LPSTR) spn,
156 ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH,
157 0,
158 SECURITY_NETWORK_DREP,
159 client->has_ctx > 0 ? &inbuf : NULL,
160 0,
161 &client->ctx,
162 &outbuf,
163 &context_attr,
164 NULL);
165
166 if (client->status != SEC_E_OK && client->status != SEC_I_CONTINUE_NEEDED) {
167 return SSPI_ERROR;
168 }
169
170 client->has_ctx = 1;
171
172 *output = malloc(out_bufs[0].cbBuffer);
173 *output_length = out_bufs[0].cbBuffer;
174 memcpy(*output, out_bufs[0].pvBuffer, *output_length);
175 sspi_functions->FreeContextBuffer(out_bufs[0].pvBuffer);
176
177 if (client->status == SEC_I_CONTINUE_NEEDED) {
178 return SSPI_CONTINUE;
179 }
180
181 return SSPI_OK;
182}
183
184int sspi_client_wrap_msg(
185 sspi_client_state *client,
186 PVOID input,
187 ULONG input_length,
188 PVOID* output,
189 ULONG* output_length
190)
191{
192 SecPkgContext_Sizes sizes;
193
194 client->status = sspi_functions->QueryContextAttributes(&client->ctx, SECPKG_ATTR_SIZES, &sizes);
195 if (client->status != SEC_E_OK) {
196 return SSPI_ERROR;
197 }
198
199 char *msg = malloc((sizes.cbSecurityTrailer + input_length + sizes.cbBlockSize) * sizeof(char));
200 memcpy(&msg[sizes.cbSecurityTrailer], input, input_length);
201
202 SecBuffer wrap_bufs[3];
203 SecBufferDesc wrap_buf_desc;
204 wrap_buf_desc.cBuffers = 3;
205 wrap_buf_desc.pBuffers = wrap_bufs;
206 wrap_buf_desc.ulVersion = SECBUFFER_VERSION;
207
208 wrap_bufs[0].cbBuffer = sizes.cbSecurityTrailer;
209 wrap_bufs[0].BufferType = SECBUFFER_TOKEN;
210 wrap_bufs[0].pvBuffer = msg;
211
212 wrap_bufs[1].cbBuffer = input_length;
213 wrap_bufs[1].BufferType = SECBUFFER_DATA;
214 wrap_bufs[1].pvBuffer = msg + sizes.cbSecurityTrailer;
215
216 wrap_bufs[2].cbBuffer = sizes.cbBlockSize;
217 wrap_bufs[2].BufferType = SECBUFFER_PADDING;
218 wrap_bufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + input_length;
219
220 client->status = sspi_functions->EncryptMessage(&client->ctx, SECQOP_WRAP_NO_ENCRYPT, &wrap_buf_desc, 0);
221 if (client->status != SEC_E_OK) {
222 free(msg);
223 return SSPI_ERROR;
224 }
225
226 *output_length = wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer + wrap_bufs[2].cbBuffer;
227 *output = malloc(*output_length);
228
229 memcpy(*output, wrap_bufs[0].pvBuffer, wrap_bufs[0].cbBuffer);
230 memcpy(*output + wrap_bufs[0].cbBuffer, wrap_bufs[1].pvBuffer, wrap_bufs[1].cbBuffer);
231 memcpy(*output + wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer, wrap_bufs[2].pvBuffer, wrap_bufs[2].cbBuffer);
232
233 free(msg);
234
235 return SSPI_OK;
236}
237
238int sspi_client_destroy(
239 sspi_client_state *client
240)
241{
242 if (client->has_ctx > 0) {
243 sspi_functions->DeleteSecurityContext(&client->ctx);
244 }
245
246 sspi_functions->FreeCredentialsHandle(&client->cred);
247
248 return SSPI_OK;
249}