branch update for HEAD-2003021201
[reactos.git] / lib / msafd / misc / dllmain.c
1 /*
2  * COPYRIGHT:   See COPYING in the top level directory
3  * PROJECT:     ReactOS Ancillary Function Driver DLL
4  * FILE:        misc/dllmain.c
5  * PURPOSE:     DLL entry point
6  * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7  * REVISIONS:
8  *   CSH 01/09-2000 Created
9  */
10 #include <string.h>
11 #include <msafd.h>
12 #include <helpers.h>
13
14 #ifdef DBG
15
16 /* See debug.h for debug/trace constants */
17 DWORD DebugTraceLevel = MIN_TRACE;
18 //DWORD DebugTraceLevel = DEBUG_ULTRA;
19
20 #endif /* DBG */
21
22 /* To make the linker happy */
23 VOID STDCALL KeBugCheck (ULONG  BugCheckCode) {}
24
25
26 HANDLE GlobalHeap;
27 WSPUPCALLTABLE Upcalls;
28 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
29 CRITICAL_SECTION InitCriticalSection;
30 DWORD StartupCount = 0;
31 HANDLE CommandChannel;
32
33
34 NTSTATUS OpenSocket(
35   SOCKET *Socket,
36   INT AddressFamily,
37   INT SocketType,
38   INT Protocol,
39   PVOID HelperContext,
40   DWORD NotificationEvents,
41   PUNICODE_STRING TdiDeviceName)
42 /*
43  * FUNCTION: Opens a socket
44  * ARGUMENTS:
45  *     Socket             = Address of buffer to place socket descriptor
46  *     AddressFamily      = Address family
47  *     SocketType         = Type of socket
48  *     Protocol           = Protocol type
49  *     HelperContext      = Pointer to context information for helper DLL
50  *     NotificationEvents = Events for which helper DLL is to be notified
51  *     TdiDeviceName      = Pointer to name of TDI device to use
52  * RETURNS:
53  *     Status of operation
54  */
55 {
56   OBJECT_ATTRIBUTES ObjectAttributes;
57   PAFD_SOCKET_INFORMATION SocketInfo;
58   PFILE_FULL_EA_INFORMATION EaInfo;
59   UNICODE_STRING DeviceName;
60   IO_STATUS_BLOCK Iosb;
61   HANDLE FileHandle;
62   NTSTATUS Status;
63   ULONG EaLength;
64   ULONG EaShort;
65
66   AFD_DbgPrint(MAX_TRACE, ("Socket (0x%X)  TdiDeviceName (%wZ)\n",
67     Socket, TdiDeviceName));
68
69   AFD_DbgPrint(MAX_TRACE, ("Socket2 (0x%X)  TdiDeviceName (%S)\n",
70     Socket, TdiDeviceName->Buffer));
71
72   EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
73     AFD_SOCKET_LENGTH +
74     sizeof(AFD_SOCKET_INFORMATION);
75
76   EaLength = EaShort + TdiDeviceName->Length + sizeof(WCHAR);
77
78   EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
79   if (!EaInfo) {
80     return STATUS_INSUFFICIENT_RESOURCES;
81   }
82
83   RtlZeroMemory(EaInfo, EaLength);
84   EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
85   RtlCopyMemory(EaInfo->EaName,
86     AfdSocket,
87     AFD_SOCKET_LENGTH);
88   EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
89
90   SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
91   SocketInfo->CommandChannel     = FALSE;
92   SocketInfo->AddressFamily      = AddressFamily;
93   SocketInfo->SocketType         = SocketType;
94   SocketInfo->Protocol           = Protocol;
95   SocketInfo->HelperContext      = HelperContext;
96   SocketInfo->NotificationEvents = NotificationEvents;
97   /* Zeroed above so initialized to a wildcard address if a raw socket */
98   SocketInfo->Name.sa_family     = AddressFamily;
99
100   /* Store TDI device name last in buffer */
101   SocketInfo->TdiDeviceName.Buffer = (PWCHAR)(EaInfo + EaShort);
102   SocketInfo->TdiDeviceName.MaximumLength = TdiDeviceName->Length + sizeof(WCHAR);
103   RtlCopyUnicodeString(&SocketInfo->TdiDeviceName, TdiDeviceName);
104
105   AFD_DbgPrint(MAX_TRACE, ("EaInfo at (0x%X)  EaLength is (%d).\n", (UINT)EaInfo, (INT)EaLength));
106
107   RtlInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
108         InitializeObjectAttributes(
109     &ObjectAttributes,
110     &DeviceName,
111     0,
112     NULL,
113     NULL);
114
115   Status = NtCreateFile(
116     &FileHandle,
117     FILE_GENERIC_READ | FILE_GENERIC_WRITE,
118     &ObjectAttributes,
119     &Iosb,
120     NULL,
121                 0,
122                 0,
123                 FILE_OPEN,
124                 FILE_SYNCHRONOUS_IO_ALERT,
125     EaInfo,
126     EaLength);
127
128   HeapFree(GlobalHeap, 0, EaInfo);
129
130   if (!NT_SUCCESS(Status)) {
131     AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
132       (UINT)Status));
133     return STATUS_INSUFFICIENT_RESOURCES;
134   }
135
136   *Socket = (SOCKET)FileHandle;
137
138   return STATUS_SUCCESS;
139 }
140
141
142 SOCKET
143 WSPAPI
144 WSPSocket(
145   IN  INT af,
146   IN  INT type,
147   IN  INT protocol,
148   IN  LPWSAPROTOCOL_INFOW lpProtocolInfo,
149   IN  GROUP g,
150   IN  DWORD dwFlags,
151   OUT LPINT lpErrno)
152 /*
153  * FUNCTION: Creates a new socket
154  * ARGUMENTS:
155  *     af             = Address family
156  *     type           = Socket type
157  *     protocol       = Protocol type
158  *     lpProtocolInfo = Pointer to protocol information
159  *     g              = Reserved
160  *     dwFlags        = Socket flags
161  *     lpErrno        = Address of buffer for error information
162  * RETURNS:
163  *     Created socket, or INVALID_SOCKET if it could not be created
164  */
165 {
166   WSAPROTOCOL_INFOW ProtocolInfo;
167   UNICODE_STRING TdiDeviceName;
168   DWORD NotificationEvents;
169   PWSHELPER_DLL HelperDLL;
170   PVOID HelperContext;
171   INT AddressFamily;
172   NTSTATUS NtStatus;
173   INT SocketType;
174   SOCKET Socket2;
175   SOCKET Socket;
176   INT Protocol;
177   INT Status;
178
179   AFD_DbgPrint(MAX_TRACE, ("af (%d)  type (%d)  protocol (%d).\n",
180     af, type, protocol));
181
182   if (!lpProtocolInfo) {
183     lpProtocolInfo = &ProtocolInfo;
184     ZeroMemory(&ProtocolInfo, sizeof(WSAPROTOCOL_INFOW));
185
186     ProtocolInfo.iAddressFamily = af;
187     ProtocolInfo.iSocketType    = type;
188     ProtocolInfo.iProtocol      = protocol;
189   }
190
191   HelperDLL = LocateHelperDLL(lpProtocolInfo);
192   if (!HelperDLL) {
193     *lpErrno = WSAEAFNOSUPPORT;
194     return INVALID_SOCKET;
195   }
196
197   AddressFamily = lpProtocolInfo->iAddressFamily;
198   SocketType    = lpProtocolInfo->iSocketType;
199   Protocol      = lpProtocolInfo->iProtocol;
200
201   Status = HelperDLL->EntryTable.lpWSHOpenSocket2(
202     &AddressFamily,
203     &SocketType,
204     &Protocol,
205     0,
206     0,
207     &TdiDeviceName,
208     &HelperContext,
209     &NotificationEvents);
210   if (Status != NO_ERROR) {
211     AFD_DbgPrint(MAX_TRACE, ("WinSock Helper DLL failed (0x%X).\n", Status));
212     *lpErrno = Status;
213     return INVALID_SOCKET;
214   }
215
216   NtStatus = OpenSocket(&Socket,
217     AddressFamily,
218     SocketType,
219     Protocol,
220     HelperContext,
221     NotificationEvents,
222     &TdiDeviceName);
223
224   RtlFreeUnicodeString(&TdiDeviceName);
225   if (!NT_SUCCESS(NtStatus)) {
226     *lpErrno = RtlNtStatusToDosError(Status);
227     return INVALID_SOCKET;
228   }
229
230   /* FIXME: Assumes catalog entry id to be 1 */
231   Socket2 = Upcalls.lpWPUModifyIFSHandle(1, Socket, lpErrno);
232
233   if (Socket2 == INVALID_SOCKET) {
234     /* FIXME: Cleanup */
235     AFD_DbgPrint(MIN_TRACE, ("FIXME: Cleanup.\n"));
236     return INVALID_SOCKET;
237   }
238
239   *lpErrno = NO_ERROR;
240
241   AFD_DbgPrint(MID_TRACE, ("Returning socket descriptor (0x%X).\n", Socket2));
242
243   return Socket2;
244 }
245
246
247 INT
248 WSPAPI
249 WSPCloseSocket(
250   IN  SOCKET s,
251   OUT   LPINT lpErrno)
252 /*
253  * FUNCTION: Closes an open socket
254  * ARGUMENTS:
255  *     s       = Socket descriptor
256  *     lpErrno = Address of buffer for error information
257  * RETURNS:
258  *     NO_ERROR, or SOCKET_ERROR if the socket could not be closed
259  */
260 {
261   NTSTATUS Status;
262
263   AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
264
265   Status = NtClose((HANDLE)s);
266
267   if (NT_SUCCESS(Status)) {
268     *lpErrno = NO_ERROR;
269     return NO_ERROR;
270   }
271
272   *lpErrno = WSAENOTSOCK;
273   return SOCKET_ERROR;
274 }
275
276
277 INT
278 WSPAPI
279 WSPBind(
280   IN  SOCKET s,
281   IN  CONST LPSOCKADDR name, 
282   IN  INT namelen, 
283   OUT LPINT lpErrno)
284 /*
285  * FUNCTION: Associates a local address with a socket
286  * ARGUMENTS:
287  *     s       = Socket descriptor
288  *     name    = Pointer to local address
289  *     namelen = Length of name
290  *     lpErrno = Address of buffer for error information
291  * RETURNS:
292  *     0, or SOCKET_ERROR if the socket could not be bound
293  */
294 {
295   FILE_REQUEST_BIND Request;
296   FILE_REPLY_BIND Reply;
297   IO_STATUS_BLOCK Iosb;
298   NTSTATUS Status;
299
300   AFD_DbgPrint(MAX_TRACE, ("s (0x%X)  name (0x%X)  namelen (%d).\n", s, name, namelen));
301
302   RtlCopyMemory(&Request.Name, name, sizeof(SOCKADDR));
303
304   Status = NtDeviceIoControlFile(
305     (HANDLE)s,
306     NULL,
307                 NULL,
308                 NULL,
309                 &Iosb,
310                 IOCTL_AFD_BIND,
311                 &Request,
312                 sizeof(FILE_REQUEST_BIND),
313                 &Reply,
314                 sizeof(FILE_REPLY_BIND));
315   if (Status == STATUS_PENDING) {
316     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
317     /* FIXME: Wait only for blocking sockets */
318                 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
319   }
320
321   if (!NT_SUCCESS(Status)) {
322           *lpErrno = Reply.Status;
323     return SOCKET_ERROR;
324         }
325
326   return 0;
327 }
328
329 INT
330 WSPAPI
331 WSPListen(
332     IN  SOCKET s,
333     IN  INT backlog,
334     OUT LPINT lpErrno)
335 /*
336  * FUNCTION: Listens for incoming connections
337  * ARGUMENTS:
338  *     s       = Socket descriptor
339  *     backlog = Maximum number of pending connection requests
340  *     lpErrno = Address of buffer for error information
341  * RETURNS:
342  *     0, or SOCKET_ERROR if the socket could not be bound
343  */
344 {
345   FILE_REQUEST_LISTEN Request;
346   FILE_REPLY_LISTEN Reply;
347   IO_STATUS_BLOCK Iosb;
348   NTSTATUS Status;
349
350   AFD_DbgPrint(MAX_TRACE, ("s (0x%X)  backlog (%d).\n", s, backlog));
351
352   Request.Backlog = backlog;
353
354   Status = NtDeviceIoControlFile(
355     (HANDLE)s,
356     NULL,
357                 NULL,
358                 NULL,
359                 &Iosb,
360                 IOCTL_AFD_LISTEN,
361                 &Request,
362                 sizeof(FILE_REQUEST_LISTEN),
363                 &Reply,
364                 sizeof(FILE_REPLY_LISTEN));
365   if (Status == STATUS_PENDING) {
366     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
367     /* FIXME: Wait only for blocking sockets */
368                 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
369   }
370
371   if (!NT_SUCCESS(Status)) {
372           *lpErrno = Reply.Status;
373     return SOCKET_ERROR;
374         }
375
376   return 0;
377 }
378
379
380 INT
381 WSPAPI
382 WSPSelect(
383   IN      INT nfds,
384   IN OUT  LPFD_SET readfds,
385   IN OUT  LPFD_SET writefds,
386   IN OUT  LPFD_SET exceptfds,
387   IN      CONST LPTIMEVAL timeout,
388   OUT     LPINT lpErrno)
389 /*
390  * FUNCTION: Returns status of one or more sockets
391  * ARGUMENTS:
392  *     nfds      = Always ignored
393  *     readfds   = Pointer to socket set to be checked for readability (optional)
394  *     writefds  = Pointer to socket set to be checked for writability (optional)
395  *     exceptfds = Pointer to socket set to be checked for errors (optional)
396  *     timeout   = Pointer to a TIMEVAL structure indicating maximum wait time
397  *                 (NULL means wait forever)
398  *     lpErrno   = Address of buffer for error information
399  * RETURNS:
400  *     Number of ready socket descriptors, or SOCKET_ERROR if an error ocurred
401  */
402 {
403   PFILE_REQUEST_SELECT Request;
404   FILE_REPLY_SELECT Reply;
405   IO_STATUS_BLOCK Iosb;
406   NTSTATUS Status;
407   DWORD Size;
408   DWORD ReadSize;
409   DWORD WriteSize;
410   DWORD ExceptSize;
411   PVOID Current;
412
413   AFD_DbgPrint(MAX_TRACE, ("readfds (0x%X)  writefds (0x%X)  exceptfds (0x%X).\n",
414         readfds, writefds, exceptfds));
415 #if 0
416   /* FIXME: For now, all reads are timed out immediately */
417   if (readfds != NULL) {
418     AFD_DbgPrint(MID_TRACE, ("Timing out read query.\n"));
419     *lpErrno = WSAETIMEDOUT;
420     return SOCKET_ERROR;
421   }
422
423   /* FIXME: For now, always allow write */
424   if (writefds != NULL) {
425     AFD_DbgPrint(MID_TRACE, ("Setting one socket writeable.\n"));
426     *lpErrno = NO_ERROR;
427     return 1;
428   }
429 #endif
430
431   ReadSize = 0;
432   if ((readfds != NULL) && (readfds->fd_count > 0)) {
433     ReadSize = (readfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
434   }
435
436   WriteSize = 0;
437   if ((writefds != NULL) && (writefds->fd_count > 0)) {
438     WriteSize = (writefds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
439   }
440
441   ExceptSize = 0;
442   if ((exceptfds != NULL) && (exceptfds->fd_count > 0)) {
443     ExceptSize = (exceptfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
444   }
445
446   Size = ReadSize + WriteSize + ExceptSize;
447
448   Request = (PFILE_REQUEST_SELECT)HeapAlloc(
449     GlobalHeap, 0, sizeof(FILE_REQUEST_SELECT) + Size);
450   if (!Request) {
451     *lpErrno = WSAENOBUFS;
452     return SOCKET_ERROR;
453   }
454
455   /* Put FD SETs after request structure */
456   Current = (Request + 1);
457
458   if (ReadSize > 0) {
459     Request->ReadFDSet = (LPFD_SET)Current;
460     Current += ReadSize;
461     RtlCopyMemory(Request->ReadFDSet, readfds, ReadSize);
462   } else {
463     Request->ReadFDSet = NULL;
464   }
465
466   if (WriteSize > 0) {
467     Request->WriteFDSet = (LPFD_SET)Current;
468     Current += WriteSize;
469     RtlCopyMemory(Request->WriteFDSet, writefds, WriteSize);
470   } else {
471     Request->WriteFDSet = NULL;
472   }
473
474   if (ExceptSize > 0) {
475     Request->ExceptFDSet = (LPFD_SET)Current;
476     RtlCopyMemory(Request->ExceptFDSet, exceptfds, ExceptSize);
477   } else {
478     Request->ExceptFDSet = NULL;
479   }
480
481   AFD_DbgPrint(MAX_TRACE, ("R1 (0x%X)  W1 (0x%X).\n", Request->ReadFDSet, Request->WriteFDSet));
482
483   Status = NtDeviceIoControlFile(
484     CommandChannel,
485     NULL,
486                 NULL,
487                 NULL,   
488                 &Iosb,
489                 IOCTL_AFD_SELECT,
490                 Request,
491                 sizeof(FILE_REQUEST_SELECT) + Size,
492                 &Reply,
493                 sizeof(FILE_REPLY_SELECT));
494
495   HeapFree(GlobalHeap, 0, Request);
496
497         if (Status == STATUS_PENDING) {
498     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
499     /* FIXME: Wait only for blocking sockets */
500                 Status = NtWaitForSingleObject(CommandChannel, FALSE, NULL);
501   }
502
503   if (!NT_SUCCESS(Status)) {
504     AFD_DbgPrint(MAX_TRACE, ("Status (0x%X).\n", Status));
505                 *lpErrno = WSAENOBUFS;
506     return SOCKET_ERROR;
507         }
508
509   AFD_DbgPrint(MAX_TRACE, ("Select successful. Status (0x%X)  Count (0x%X).\n",
510     Reply.Status, Reply.SocketCount));
511
512   *lpErrno = Reply.Status;
513
514   return Reply.SocketCount;
515 }
516
517 SOCKET
518 WSPAPI
519 WSPAccept(
520   IN      SOCKET s,
521   OUT     LPSOCKADDR addr,
522   IN OUT  LPINT addrlen,
523   IN      LPCONDITIONPROC lpfnCondition,
524   IN      DWORD dwCallbackData,
525   OUT     LPINT lpErrno)
526 {
527   FILE_REQUEST_ACCEPT Request;
528   FILE_REPLY_ACCEPT Reply;
529   IO_STATUS_BLOCK Iosb;
530   NTSTATUS Status;
531
532   AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
533
534   Request.addr = addr;
535   Request.addrlen = *addrlen;
536   Request.lpfnCondition = lpfnCondition;
537   Request.dwCallbackData = dwCallbackData;
538
539   Status = NtDeviceIoControlFile(
540     (HANDLE)s,
541     NULL,
542                 NULL,
543                 NULL,
544                 &Iosb,
545                 IOCTL_AFD_ACCEPT,
546                 &Request,
547                 sizeof(FILE_REQUEST_ACCEPT),
548                 &Reply,
549                 sizeof(FILE_REPLY_ACCEPT));
550   if (Status == STATUS_PENDING) {
551     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
552     /* FIXME: Wait only for blocking sockets */
553                 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
554   }
555
556   if (!NT_SUCCESS(Status)) {
557           *lpErrno = Reply.Status;
558     return INVALID_SOCKET;
559         }
560
561   *addrlen = Reply.addrlen;
562
563   return Reply.Socket;
564 }
565
566
567 INT
568 WSPAPI
569 WSPConnect(
570   IN  SOCKET s,
571   IN  CONST LPSOCKADDR name,
572   IN  INT namelen,
573   IN  LPWSABUF lpCallerData,
574   OUT LPWSABUF lpCalleeData,
575   IN  LPQOS lpSQOS,
576   IN  LPQOS lpGQOS,
577   OUT LPINT lpErrno)
578 {
579   FILE_REQUEST_CONNECT Request;
580   FILE_REPLY_CONNECT Reply;
581   IO_STATUS_BLOCK Iosb;
582   NTSTATUS Status;
583
584   AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
585
586   Request.name = name;
587   Request.namelen = namelen;
588   Request.lpCallerData = lpCallerData;
589   Request.lpCalleeData = lpCalleeData;
590   Request.lpSQOS = lpSQOS;
591   Request.lpGQOS = lpGQOS;
592
593   Status = NtDeviceIoControlFile(
594     (HANDLE)s,
595     NULL,
596                 NULL,
597                 NULL,
598                 &Iosb,
599                 IOCTL_AFD_CONNECT,
600                 &Request,
601                 sizeof(FILE_REQUEST_CONNECT),
602                 &Reply,
603                 sizeof(FILE_REPLY_CONNECT));
604   if (Status == STATUS_PENDING) {
605     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
606     /* FIXME: Wait only for blocking sockets */
607                 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
608   }
609
610   if (!NT_SUCCESS(Status)) {
611           *lpErrno = Reply.Status;
612     return INVALID_SOCKET;
613         }
614
615   return 0;
616 }
617
618
619 NTSTATUS OpenCommandChannel(
620   VOID)
621 /*
622  * FUNCTION: Opens a command channel to afd.sys
623  * ARGUMENTS:
624  *     None
625  * RETURNS:
626  *     Status of operation
627  */
628 {
629   OBJECT_ATTRIBUTES ObjectAttributes;
630   PAFD_SOCKET_INFORMATION SocketInfo;
631   PFILE_FULL_EA_INFORMATION EaInfo;
632   UNICODE_STRING DeviceName;
633   IO_STATUS_BLOCK Iosb;
634   HANDLE FileHandle;
635   NTSTATUS Status;
636   ULONG EaLength;
637   ULONG EaShort;
638
639   AFD_DbgPrint(MAX_TRACE, ("Called\n"));
640
641   EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
642     AFD_SOCKET_LENGTH +
643     sizeof(AFD_SOCKET_INFORMATION);
644
645   EaLength = EaShort;
646
647   EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
648   if (!EaInfo) {
649     return STATUS_INSUFFICIENT_RESOURCES;
650   }
651
652   RtlZeroMemory(EaInfo, EaLength);
653   EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
654   RtlCopyMemory(EaInfo->EaName,
655     AfdSocket,
656     AFD_SOCKET_LENGTH);
657   EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
658
659   SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
660   SocketInfo->CommandChannel = TRUE;
661
662   RtlInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
663         InitializeObjectAttributes(
664     &ObjectAttributes,
665     &DeviceName,
666     0,
667     NULL,
668     NULL);
669
670   Status = NtCreateFile(
671     &FileHandle,
672     FILE_GENERIC_READ | FILE_GENERIC_WRITE,
673     &ObjectAttributes,
674     &Iosb,
675     NULL,
676                 0,
677                 0,
678                 FILE_OPEN,
679                 FILE_SYNCHRONOUS_IO_ALERT,
680     EaInfo,
681     EaLength);
682
683   if (!NT_SUCCESS(Status)) {
684     AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
685       (UINT)Status));
686     return Status;
687   }
688
689   CommandChannel = FileHandle;
690
691   return STATUS_SUCCESS;
692 }
693
694
695 NTSTATUS CloseCommandChannel(
696   VOID)
697 /*
698  * FUNCTION: Closes command channel to afd.sys
699  * ARGUMENTS:
700  *     None
701  * RETURNS:
702  *     Status of operation
703  */
704 {
705   AFD_DbgPrint(MAX_TRACE, ("Called.\n"));
706
707   return NtClose(CommandChannel);
708 }
709
710
711 INT
712 WSPAPI
713 WSPStartup(
714   IN  WORD wVersionRequested,
715   OUT LPWSPDATA lpWSPData,
716   IN  LPWSAPROTOCOL_INFOW lpProtocolInfo,
717   IN  WSPUPCALLTABLE UpcallTable,
718   OUT LPWSPPROC_TABLE lpProcTable)
719 /*
720  * FUNCTION: Initialize service provider for a client
721  * ARGUMENTS:
722  *     wVersionRequested = Highest WinSock SPI version that the caller can use
723  *     lpWSPData         = Address of WSPDATA structure to initialize
724  *     lpProtocolInfo    = Pointer to structure that defines the desired protocol
725  *     UpcallTable       = Pointer to upcall table of the WinSock DLL
726  *     lpProcTable       = Address of procedure table to initialize
727  * RETURNS:
728  *     Status of operation
729  */
730 {
731   HMODULE hWS2_32;
732   INT Status;
733
734   AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
735
736   EnterCriticalSection(&InitCriticalSection);
737
738   Upcalls = UpcallTable;
739
740   if (StartupCount == 0) {
741     /* First time called */
742
743     Status = OpenCommandChannel();
744     if (NT_SUCCESS(Status)) {
745       hWS2_32 = GetModuleHandle(L"ws2_32.dll");
746       if (hWS2_32 != NULL) {
747         lpWPUCompleteOverlappedRequest = (LPWPUCOMPLETEOVERLAPPEDREQUEST)
748           GetProcAddress(hWS2_32, "WPUCompleteOverlappedRequest");
749         if (lpWPUCompleteOverlappedRequest != NULL) {
750           Status = NO_ERROR;
751           StartupCount++;
752         }
753       } else {
754         AFD_DbgPrint(MIN_TRACE, ("GetModuleHandle() failed for ws2_32.dll\n"));
755       }
756     } else {
757       AFD_DbgPrint(MIN_TRACE, ("Cannot open afd.sys\n"));
758     }
759   } else {
760     Status = NO_ERROR;
761     StartupCount++;
762   }
763
764   LeaveCriticalSection(&InitCriticalSection);
765
766   if (Status == NO_ERROR) {
767     lpProcTable->lpWSPAccept = WSPAccept;
768     lpProcTable->lpWSPAddressToString = WSPAddressToString;
769     lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
770     lpProcTable->lpWSPBind = WSPBind;
771     lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
772     lpProcTable->lpWSPCleanup = WSPCleanup;
773     lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
774     lpProcTable->lpWSPConnect = WSPConnect;
775     lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
776     lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
777     lpProcTable->lpWSPEventSelect = WSPEventSelect;
778     lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
779     lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
780     lpProcTable->lpWSPGetSockName = WSPGetSockName;
781     lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
782     lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
783     lpProcTable->lpWSPIoctl = WSPIoctl;
784     lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
785     lpProcTable->lpWSPListen = WSPListen;
786     lpProcTable->lpWSPRecv = WSPRecv;
787     lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
788     lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
789     lpProcTable->lpWSPSelect = WSPSelect;
790     lpProcTable->lpWSPSend = WSPSend;
791     lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
792     lpProcTable->lpWSPSendTo = WSPSendTo;
793     lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
794     lpProcTable->lpWSPShutdown = WSPShutdown;
795     lpProcTable->lpWSPSocket = WSPSocket;
796     lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
797
798     lpWSPData->wVersion     = MAKEWORD(2, 2);
799     lpWSPData->wHighVersion = MAKEWORD(2, 2);
800   }
801
802   AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
803
804   return Status;
805 }
806
807
808 INT
809 WSPAPI
810 WSPCleanup(
811   OUT LPINT lpErrno)
812 /*
813  * FUNCTION: Cleans up service provider for a client
814  * ARGUMENTS:
815  *     lpErrno = Address of buffer for error information
816  * RETURNS:
817  *     0 if successful, or SOCKET_ERROR if not
818  */
819 {
820   AFD_DbgPrint(MAX_TRACE, ("\n"));
821
822   EnterCriticalSection(&InitCriticalSection);
823
824   if (StartupCount > 0) {
825     StartupCount--;
826
827     if (StartupCount == 0) {
828       AFD_DbgPrint(MAX_TRACE, ("Cleaning up msafd.dll.\n"));
829
830       CloseCommandChannel();
831     }
832   }
833
834   LeaveCriticalSection(&InitCriticalSection);
835
836   AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
837
838   *lpErrno = NO_ERROR;
839
840   return 0;
841 }
842
843
844 BOOL
845 STDCALL
846 DllMain(HANDLE hInstDll,
847         ULONG dwReason,
848         PVOID Reserved)
849 {
850     AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll\n"));
851
852     switch (dwReason) {
853     case DLL_PROCESS_ATTACH:
854         /* Don't need thread attach notifications
855            so disable them to improve performance */
856         DisableThreadLibraryCalls(hInstDll);
857
858         InitializeCriticalSection(&InitCriticalSection);
859
860         GlobalHeap = GetProcessHeap();
861
862         CreateHelperDLLDatabase();
863         break;
864
865     case DLL_THREAD_ATTACH:
866         break;
867
868     case DLL_THREAD_DETACH:
869         break;
870
871     case DLL_PROCESS_DETACH:
872
873         DestroyHelperDLLDatabase();
874
875         DeleteCriticalSection(&InitCriticalSection);
876
877         break;
878     }
879
880     AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
881
882     return TRUE;
883 }
884
885 /* EOF */