update for HEAD-2003021201
[reactos.git] / lib / msafd / misc / helpers.c
1 /*
2  * COPYRIGHT:   See COPYING in the top level directory
3  * PROJECT:     ReactOS Ancillary Function Driver DLL
4  * FILE:        misc/helpers.c
5  * PURPOSE:     Helper DLL management
6  * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7  * REVISIONS:
8  *   CSH 01/09-2000 Created
9  */
#include <string.h>
#include <msafd.h>
#include <helpers.h>
12
13 CRITICAL_SECTION HelperDLLDatabaseLock;
14 LIST_ENTRY HelperDLLDatabaseListHead;
15
16 PWSHELPER_DLL CreateHelperDLL(
17     LPWSTR LibraryName)
18 {
19     PWSHELPER_DLL HelperDLL;
20
21     HelperDLL = HeapAlloc(GlobalHeap, 0, sizeof(WSHELPER_DLL));
22     if (!HelperDLL)
23         return NULL;
24
25     InitializeCriticalSection(&HelperDLL->Lock);
26     HelperDLL->hModule = NULL;
27     lstrcpyW(HelperDLL->LibraryName, LibraryName);
28     HelperDLL->Mapping = NULL;
29
30     EnterCriticalSection(&HelperDLLDatabaseLock);
31     InsertTailList(&HelperDLLDatabaseListHead, &HelperDLL->ListEntry);
32     LeaveCriticalSection(&HelperDLLDatabaseLock);
33
34     AFD_DbgPrint(MAX_TRACE, ("Returning helper at (0x%X).\n", HelperDLL));
35
36     return HelperDLL;
37 }
38
39
40 INT DestroyHelperDLL(
41     PWSHELPER_DLL HelperDLL)
42 {
43     INT Status;
44
45     AFD_DbgPrint(MAX_TRACE, ("HelperDLL (0x%X).\n", HelperDLL));
46
47     EnterCriticalSection(&HelperDLLDatabaseLock);
48     RemoveEntryList(&HelperDLL->ListEntry);
49     LeaveCriticalSection(&HelperDLLDatabaseLock);
50
51     if (HelperDLL->hModule) {
52         Status = UnloadHelperDLL(HelperDLL);
53     } else {
54         Status = NO_ERROR;
55     }
56
57     if (HelperDLL->Mapping)
58         HeapFree(GlobalHeap, 0, HelperDLL->Mapping);
59
60     DeleteCriticalSection(&HelperDLL->Lock);
61
62     HeapFree(GlobalHeap, 0, HelperDLL);
63
64     return Status;
65 }
66
67
68 PWSHELPER_DLL LocateHelperDLL(
69     LPWSAPROTOCOL_INFOW lpProtocolInfo)
70 {
71     PLIST_ENTRY CurrentEntry;
72     PWSHELPER_DLL HelperDLL;
73     UINT i;
74
75     EnterCriticalSection(&HelperDLLDatabaseLock);
76     CurrentEntry = HelperDLLDatabaseListHead.Flink;
77     while (CurrentEntry != &HelperDLLDatabaseListHead) {
78             HelperDLL = CONTAINING_RECORD(CurrentEntry,
79                                       WSHELPER_DLL,
80                                       ListEntry);
81
82         for (i = 0; i < HelperDLL->Mapping->Rows; i++) {
83             if ((lpProtocolInfo->iAddressFamily == (INT) HelperDLL->Mapping->Mapping[i].AddressFamily) &&
84                 (lpProtocolInfo->iSocketType    == (INT) HelperDLL->Mapping->Mapping[i].SocketType) &&
85                 ((lpProtocolInfo->iProtocol     == (INT) HelperDLL->Mapping->Mapping[i].Protocol) ||
86                 (lpProtocolInfo->iSocketType    == SOCK_RAW))) {
87                 LeaveCriticalSection(&HelperDLLDatabaseLock);
88                 AFD_DbgPrint(MAX_TRACE, ("Returning helper DLL at (0x%X).\n", HelperDLL));
89                 return HelperDLL;
90             }
91         }
92
93         CurrentEntry = CurrentEntry->Flink;
94     }
95     LeaveCriticalSection(&HelperDLLDatabaseLock);
96
97     AFD_DbgPrint(MAX_TRACE, ("Could not locate helper DLL.\n"));
98
99     return NULL;
100 }
101
102
103 INT GetHelperDLLEntries(
104     PWSHELPER_DLL HelperDLL)
105 {
106     PVOID e;
107
108     e = GetProcAddress(HelperDLL->hModule, "WSHAddressToString");
109     if (!e) return ERROR_BAD_PROVIDER;
110         ((PVOID) HelperDLL->EntryTable.lpWSHAddressToString) = e;
111
112     e = GetProcAddress(HelperDLL->hModule, "WSHEnumProtocols");
113     if (!e) return ERROR_BAD_PROVIDER;
114         ((PVOID) HelperDLL->EntryTable.lpWSHEnumProtocols) = e;
115
116     e = GetProcAddress(HelperDLL->hModule, "WSHGetBroadcastSockaddr");
117     if (!e) return ERROR_BAD_PROVIDER;
118         ((PVOID) HelperDLL->EntryTable.lpWSHGetBroadcastSockaddr) = e;
119
120     e = GetProcAddress(HelperDLL->hModule, "WSHGetProviderGuid");
121     if (!e) return ERROR_BAD_PROVIDER;
122         ((PVOID) HelperDLL->EntryTable.lpWSHGetProviderGuid) = e;
123
124         e = GetProcAddress(HelperDLL->hModule, "WSHGetSockaddrType");
125     if (!e) return ERROR_BAD_PROVIDER;
126         ((PVOID) HelperDLL->EntryTable.lpWSHGetSockaddrType) = e;
127
128     e = GetProcAddress(HelperDLL->hModule, "WSHGetSocketInformation");
129     if (!e) return ERROR_BAD_PROVIDER;
130         ((PVOID) HelperDLL->EntryTable.lpWSHGetSocketInformation) = e;
131
132     e = GetProcAddress(HelperDLL->hModule, "WSHGetWildcardSockaddr");
133     if (!e) return ERROR_BAD_PROVIDER;
134         ((PVOID) HelperDLL->EntryTable.lpWSHGetWildcardSockaddr) = e;
135
136     e = GetProcAddress(HelperDLL->hModule, "WSHGetWinsockMapping");
137     if (!e) return ERROR_BAD_PROVIDER;
138         ((PVOID) HelperDLL->EntryTable.lpWSHGetWinsockMapping) = e;
139
140     e = GetProcAddress(HelperDLL->hModule, "WSHGetWSAProtocolInfo");
141     if (!e) return ERROR_BAD_PROVIDER;
142         ((PVOID) HelperDLL->EntryTable.lpWSHGetWSAProtocolInfo) = e;
143
144     e = GetProcAddress(HelperDLL->hModule, "WSHIoctl");
145     if (!e) return ERROR_BAD_PROVIDER;
146         ((PVOID) HelperDLL->EntryTable.lpWSHIoctl) = e;
147
148     e = GetProcAddress(HelperDLL->hModule, "WSHJoinLeaf");
149     if (!e) return ERROR_BAD_PROVIDER;
150         ((PVOID) HelperDLL->EntryTable.lpWSHJoinLeaf) = e;
151
152     e = GetProcAddress(HelperDLL->hModule, "WSHNotify");
153     if (!e) return ERROR_BAD_PROVIDER;
154         ((PVOID) HelperDLL->EntryTable.lpWSHNotify) = e;
155
156     e = GetProcAddress(HelperDLL->hModule, "WSHOpenSocket");
157     if (!e) return ERROR_BAD_PROVIDER;
158         ((PVOID) HelperDLL->EntryTable.lpWSHOpenSocket) = e;
159
160     e = GetProcAddress(HelperDLL->hModule, "WSHOpenSocket2");
161     if (!e) return ERROR_BAD_PROVIDER;
162         ((PVOID) HelperDLL->EntryTable.lpWSHOpenSocket2) = e;
163
164     e = GetProcAddress(HelperDLL->hModule, "WSHSetSocketInformation");
165     if (!e) return ERROR_BAD_PROVIDER;
166         ((PVOID) HelperDLL->EntryTable.lpWSHSetSocketInformation) = e;
167
168     e = GetProcAddress(HelperDLL->hModule, "WSHStringToAddress");
169     if (!e) return ERROR_BAD_PROVIDER;
170         ((PVOID) HelperDLL->EntryTable.lpWSHStringToAddress) = e;
171
172     return NO_ERROR;
173 }
174
175
176 INT LoadHelperDLL(
177     PWSHELPER_DLL HelperDLL)
178 {
179     INT Status = NO_ERROR;
180
181     AFD_DbgPrint(MAX_TRACE, ("Loading helper dll at (0x%X).\n", HelperDLL));
182
183     if (!HelperDLL->hModule) {
184         /* DLL is not loaded so load it now */
185         HelperDLL->hModule = LoadLibrary(HelperDLL->LibraryName);
186
187         AFD_DbgPrint(MAX_TRACE, ("hModule is (0x%X).\n", HelperDLL->hModule));
188
189         if (HelperDLL->hModule) {
190             Status = GetHelperDLLEntries(HelperDLL);
191         } else
192             Status = ERROR_DLL_NOT_FOUND;
193     } else
194         Status = NO_ERROR;
195
196     AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
197
198     return Status;
199 }
200
201
202 INT UnloadHelperDLL(
203     PWSHELPER_DLL HelperDLL)
204 {
205     INT Status = NO_ERROR;
206
207     AFD_DbgPrint(MAX_TRACE, ("HelperDLL (0x%X) hModule (0x%X).\n", HelperDLL, HelperDLL->hModule));
208
209     if (HelperDLL->hModule) {
210         if (!FreeLibrary(HelperDLL->hModule)) {
211             Status = GetLastError();
212         }
213         HelperDLL->hModule = NULL;
214     }
215
216     return Status;
217 }
218
219
220 VOID CreateHelperDLLDatabase(VOID)
221 {
222     PWSHELPER_DLL HelperDLL;
223
224     InitializeCriticalSection(&HelperDLLDatabaseLock);
225
226     InitializeListHead(&HelperDLLDatabaseListHead);
227
228     /* FIXME: Read helper DLL configuration from registry */
229     HelperDLL = CreateHelperDLL(L"wshtcpip.dll");
230     if (!HelperDLL)
231         return;
232
233     HelperDLL->Mapping = HeapAlloc(
234       GlobalHeap,
235       0,
236       3 * sizeof(WINSOCK_MAPPING) + 3 * sizeof(DWORD));
237     if (!HelperDLL->Mapping)
238         return;
239
240     HelperDLL->Mapping->Rows    = 3;
241     HelperDLL->Mapping->Columns = 3;
242
243     HelperDLL->Mapping->Mapping[0].AddressFamily = AF_INET;
244     HelperDLL->Mapping->Mapping[0].SocketType    = SOCK_STREAM;
245     HelperDLL->Mapping->Mapping[0].Protocol      = IPPROTO_TCP;
246
247     HelperDLL->Mapping->Mapping[1].AddressFamily = AF_INET;
248     HelperDLL->Mapping->Mapping[1].SocketType    = SOCK_DGRAM;
249     HelperDLL->Mapping->Mapping[1].Protocol      = IPPROTO_UDP;
250
251     HelperDLL->Mapping->Mapping[2].AddressFamily = AF_INET;
252     HelperDLL->Mapping->Mapping[2].SocketType    = SOCK_RAW;
253     HelperDLL->Mapping->Mapping[2].Protocol      = 0;
254
255     LoadHelperDLL(HelperDLL);
256 }
257
258
259 VOID DestroyHelperDLLDatabase(VOID)
260 {
261     PLIST_ENTRY CurrentEntry;
262     PLIST_ENTRY NextEntry;
263     PWSHELPER_DLL HelperDLL;
264
265     CurrentEntry = HelperDLLDatabaseListHead.Flink;
266     while (CurrentEntry != &HelperDLLDatabaseListHead) {
267         NextEntry = CurrentEntry->Flink;
268
269               HelperDLL = CONTAINING_RECORD(CurrentEntry,
270                                       WSHELPER_DLL,
271                                       ListEntry);
272
273         DestroyHelperDLL(HelperDLL);
274
275         CurrentEntry = NextEntry;
276     }
277
278     DeleteCriticalSection(&HelperDLLDatabaseLock);
279 }
280
281 /* EOF */