:pserver:cvsanon@mok.lvcm.com:/CVS/ReactOS reactos
[reactos.git] / lib / msafd / misc / dllmain.c
1 /*
2  * COPYRIGHT:   See COPYING in the top level directory
3  * PROJECT:     ReactOS Ancillary Function Driver DLL
4  * FILE:        misc/dllmain.c
5  * PURPOSE:     DLL entry point
6  * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7  * REVISIONS:
8  *   CSH 01/09-2000 Created
9  */
10 #include <msafd.h>
11 #include <helpers.h>
12
13 #ifdef DBG
14
15 /* See debug.h for debug/trace constants */
16 DWORD DebugTraceLevel = MIN_TRACE;
17 //DWORD DebugTraceLevel = DEBUG_ULTRA;
18
19 #endif /* DBG */
20
21 /* To make the linker happy */
22 VOID STDCALL KeBugCheck (ULONG  BugCheckCode) {}
23
24
25 HANDLE GlobalHeap;
26 WSPUPCALLTABLE Upcalls;
27 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
28 CRITICAL_SECTION InitCriticalSection;
29 DWORD StartupCount = 0;
30 HANDLE CommandChannel;
31
32
33 NTSTATUS OpenSocket(
34   SOCKET *Socket,
35   INT AddressFamily,
36   INT SocketType,
37   INT Protocol,
38   PVOID HelperContext,
39   DWORD NotificationEvents,
40   PUNICODE_STRING TdiDeviceName)
41 /*
42  * FUNCTION: Opens a socket
43  * ARGUMENTS:
44  *     Socket             = Address of buffer to place socket descriptor
45  *     AddressFamily      = Address family
46  *     SocketType         = Type of socket
47  *     Protocol           = Protocol type
48  *     HelperContext      = Pointer to context information for helper DLL
49  *     NotificationEvents = Events for which helper DLL is to be notified
50  *     TdiDeviceName      = Pointer to name of TDI device to use
51  * RETURNS:
52  *     Status of operation
53  */
54 {
55   OBJECT_ATTRIBUTES ObjectAttributes;
56   PAFD_SOCKET_INFORMATION SocketInfo;
57   PFILE_FULL_EA_INFORMATION EaInfo;
58   UNICODE_STRING DeviceName;
59   IO_STATUS_BLOCK Iosb;
60   HANDLE FileHandle;
61   NTSTATUS Status;
62   ULONG EaLength;
63   ULONG EaShort;
64
65   AFD_DbgPrint(MAX_TRACE, ("Socket (0x%X)  TdiDeviceName (%wZ)\n",
66     Socket, TdiDeviceName));
67
68   AFD_DbgPrint(MAX_TRACE, ("Socket2 (0x%X)  TdiDeviceName (%S)\n",
69     Socket, TdiDeviceName->Buffer));
70
71   EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
72     AFD_SOCKET_LENGTH +
73     sizeof(AFD_SOCKET_INFORMATION);
74
75   EaLength = EaShort + TdiDeviceName->Length + sizeof(WCHAR);
76
77   EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
78   if (!EaInfo) {
79     return STATUS_INSUFFICIENT_RESOURCES;
80   }
81
82   RtlZeroMemory(EaInfo, EaLength);
83   EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
84   RtlCopyMemory(EaInfo->EaName,
85     AfdSocket,
86     AFD_SOCKET_LENGTH);
87   EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
88
89   SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
90   SocketInfo->CommandChannel     = FALSE;
91   SocketInfo->AddressFamily      = AddressFamily;
92   SocketInfo->SocketType         = SocketType;
93   SocketInfo->Protocol           = Protocol;
94   SocketInfo->HelperContext      = HelperContext;
95   SocketInfo->NotificationEvents = NotificationEvents;
96   /* Zeroed above so initialized to a wildcard address if a raw socket */
97   SocketInfo->Name.sa_family     = AddressFamily;
98
99   /* Store TDI device name last in buffer */
100   SocketInfo->TdiDeviceName.Buffer = (PWCHAR)(EaInfo + EaShort);
101   SocketInfo->TdiDeviceName.MaximumLength = TdiDeviceName->Length + sizeof(WCHAR);
102   RtlCopyUnicodeString(&SocketInfo->TdiDeviceName, TdiDeviceName);
103
104   AFD_DbgPrint(MAX_TRACE, ("EaInfo at (0x%X)  EaLength is (%d).\n", (UINT)EaInfo, (INT)EaLength));
105
106   RtlInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
107         InitializeObjectAttributes(
108     &ObjectAttributes,
109     &DeviceName,
110     0,
111     NULL,
112     NULL);
113
114   Status = NtCreateFile(
115     &FileHandle,
116     FILE_GENERIC_READ | FILE_GENERIC_WRITE,
117     &ObjectAttributes,
118     &Iosb,
119     NULL,
120                 0,
121                 0,
122                 FILE_OPEN,
123                 FILE_SYNCHRONOUS_IO_ALERT,
124     EaInfo,
125     EaLength);
126
127   HeapFree(GlobalHeap, 0, EaInfo);
128
129   if (!NT_SUCCESS(Status)) {
130     AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
131       (UINT)Status));
132     return STATUS_INSUFFICIENT_RESOURCES;
133   }
134
135   *Socket = (SOCKET)FileHandle;
136
137   return STATUS_SUCCESS;
138 }
139
140
141 SOCKET
142 WSPAPI
143 WSPSocket(
144   IN  INT af,
145   IN  INT type,
146   IN  INT protocol,
147   IN  LPWSAPROTOCOL_INFOW lpProtocolInfo,
148   IN  GROUP g,
149   IN  DWORD dwFlags,
150   OUT LPINT lpErrno)
151 /*
152  * FUNCTION: Creates a new socket
153  * ARGUMENTS:
154  *     af             = Address family
155  *     type           = Socket type
156  *     protocol       = Protocol type
157  *     lpProtocolInfo = Pointer to protocol information
158  *     g              = Reserved
159  *     dwFlags        = Socket flags
160  *     lpErrno        = Address of buffer for error information
161  * RETURNS:
162  *     Created socket, or INVALID_SOCKET if it could not be created
163  */
164 {
165   WSAPROTOCOL_INFOW ProtocolInfo;
166   UNICODE_STRING TdiDeviceName;
167   DWORD NotificationEvents;
168   PWSHELPER_DLL HelperDLL;
169   PVOID HelperContext;
170   INT AddressFamily;
171   NTSTATUS NtStatus;
172   INT SocketType;
173   SOCKET Socket2;
174   SOCKET Socket;
175   INT Protocol;
176   INT Status;
177
178   AFD_DbgPrint(MAX_TRACE, ("af (%d)  type (%d)  protocol (%d).\n",
179     af, type, protocol));
180
181   if (!lpProtocolInfo) {
182     lpProtocolInfo = &ProtocolInfo;
183     ZeroMemory(&ProtocolInfo, sizeof(WSAPROTOCOL_INFOW));
184
185     ProtocolInfo.iAddressFamily = af;
186     ProtocolInfo.iSocketType    = type;
187     ProtocolInfo.iProtocol      = protocol;
188   }
189
190   HelperDLL = LocateHelperDLL(lpProtocolInfo);
191   if (!HelperDLL) {
192     *lpErrno = WSAEAFNOSUPPORT;
193     return INVALID_SOCKET;
194   }
195
196   AddressFamily = lpProtocolInfo->iAddressFamily;
197   SocketType    = lpProtocolInfo->iSocketType;
198   Protocol      = lpProtocolInfo->iProtocol;
199
200   Status = HelperDLL->EntryTable.lpWSHOpenSocket2(
201     &AddressFamily,
202     &SocketType,
203     &Protocol,
204     0,
205     0,
206     &TdiDeviceName,
207     &HelperContext,
208     &NotificationEvents);
209   if (Status != NO_ERROR) {
210     AFD_DbgPrint(MAX_TRACE, ("WinSock Helper DLL failed (0x%X).\n", Status));
211     *lpErrno = Status;
212     return INVALID_SOCKET;
213   }
214
215   NtStatus = OpenSocket(&Socket,
216     AddressFamily,
217     SocketType,
218     Protocol,
219     HelperContext,
220     NotificationEvents,
221     &TdiDeviceName);
222
223   RtlFreeUnicodeString(&TdiDeviceName);
224   if (!NT_SUCCESS(NtStatus)) {
225     *lpErrno = RtlNtStatusToDosError(Status);
226     return INVALID_SOCKET;
227   }
228
229   /* FIXME: Assumes catalog entry id to be 1 */
230   Socket2 = Upcalls.lpWPUModifyIFSHandle(1, Socket, lpErrno);
231
232   if (Socket2 == INVALID_SOCKET) {
233     /* FIXME: Cleanup */
234     AFD_DbgPrint(MIN_TRACE, ("FIXME: Cleanup.\n"));
235     return INVALID_SOCKET;
236   }
237
238   *lpErrno = NO_ERROR;
239
240   AFD_DbgPrint(MID_TRACE, ("Returning socket descriptor (0x%X).\n", Socket2));
241
242   return Socket2;
243 }
244
245
246 INT
247 WSPAPI
248 WSPCloseSocket(
249   IN  SOCKET s,
250   OUT   LPINT lpErrno)
251 /*
252  * FUNCTION: Closes an open socket
253  * ARGUMENTS:
254  *     s       = Socket descriptor
255  *     lpErrno = Address of buffer for error information
256  * RETURNS:
257  *     NO_ERROR, or SOCKET_ERROR if the socket could not be closed
258  */
259 {
260   NTSTATUS Status;
261
262   AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
263
264   Status = NtClose((HANDLE)s);
265
266   if (NT_SUCCESS(Status)) {
267     *lpErrno = NO_ERROR;
268     return NO_ERROR;
269   }
270
271   *lpErrno = WSAENOTSOCK;
272   return SOCKET_ERROR;
273 }
274
275
276 INT
277 WSPAPI
278 WSPBind(
279   IN  SOCKET s,
280   IN  CONST LPSOCKADDR name, 
281   IN  INT namelen, 
282   OUT LPINT lpErrno)
283 /*
284  * FUNCTION: Associates a local address with a socket
285  * ARGUMENTS:
286  *     s       = Socket descriptor
287  *     name    = Pointer to local address
288  *     namelen = Length of name
289  *     lpErrno = Address of buffer for error information
290  * RETURNS:
291  *     0, or SOCKET_ERROR if the socket could not be bound
292  */
293 {
294   FILE_REQUEST_BIND Request;
295   FILE_REPLY_BIND Reply;
296   IO_STATUS_BLOCK Iosb;
297   NTSTATUS Status;
298
299   AFD_DbgPrint(MAX_TRACE, ("s (0x%X)  name (0x%X)  namelen (%d).\n", s, name, namelen));
300
301   RtlCopyMemory(&Request.Name, name, sizeof(SOCKADDR));
302
303   Status = NtDeviceIoControlFile(
304     (HANDLE)s,
305     NULL,
306                 NULL,
307                 NULL,
308                 &Iosb,
309                 IOCTL_AFD_BIND,
310                 &Request,
311                 sizeof(FILE_REQUEST_BIND),
312                 &Reply,
313                 sizeof(FILE_REPLY_BIND));
314   if (Status == STATUS_PENDING) {
315     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
316     /* FIXME: Wait only for blocking sockets */
317                 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
318   }
319
320   if (!NT_SUCCESS(Status)) {
321           *lpErrno = Reply.Status;
322     return SOCKET_ERROR;
323         }
324
325   return 0;
326 }
327
328 INT
329 WSPAPI
330 WSPListen(
331     IN  SOCKET s,
332     IN  INT backlog,
333     OUT LPINT lpErrno)
334 /*
335  * FUNCTION: Listens for incoming connections
336  * ARGUMENTS:
337  *     s       = Socket descriptor
338  *     backlog = Maximum number of pending connection requests
339  *     lpErrno = Address of buffer for error information
340  * RETURNS:
341  *     0, or SOCKET_ERROR if the socket could not be bound
342  */
343 {
344   FILE_REQUEST_LISTEN Request;
345   FILE_REPLY_LISTEN Reply;
346   IO_STATUS_BLOCK Iosb;
347   NTSTATUS Status;
348
349   AFD_DbgPrint(MAX_TRACE, ("s (0x%X)  backlog (%d).\n", s, backlog));
350
351   Request.Backlog = backlog;
352
353   Status = NtDeviceIoControlFile(
354     (HANDLE)s,
355     NULL,
356                 NULL,
357                 NULL,
358                 &Iosb,
359                 IOCTL_AFD_LISTEN,
360                 &Request,
361                 sizeof(FILE_REQUEST_LISTEN),
362                 &Reply,
363                 sizeof(FILE_REPLY_LISTEN));
364   if (Status == STATUS_PENDING) {
365     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
366     /* FIXME: Wait only for blocking sockets */
367                 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
368   }
369
370   if (!NT_SUCCESS(Status)) {
371           *lpErrno = Reply.Status;
372     return SOCKET_ERROR;
373         }
374
375   return 0;
376 }
377
378
379 INT
380 WSPAPI
381 WSPSelect(
382   IN      INT nfds,
383   IN OUT  LPFD_SET readfds,
384   IN OUT  LPFD_SET writefds,
385   IN OUT  LPFD_SET exceptfds,
386   IN      CONST LPTIMEVAL timeout,
387   OUT     LPINT lpErrno)
388 /*
389  * FUNCTION: Returns status of one or more sockets
390  * ARGUMENTS:
391  *     nfds      = Always ignored
392  *     readfds   = Pointer to socket set to be checked for readability (optional)
393  *     writefds  = Pointer to socket set to be checked for writability (optional)
394  *     exceptfds = Pointer to socket set to be checked for errors (optional)
395  *     timeout   = Pointer to a TIMEVAL structure indicating maximum wait time
396  *                 (NULL means wait forever)
397  *     lpErrno   = Address of buffer for error information
398  * RETURNS:
399  *     Number of ready socket descriptors, or SOCKET_ERROR if an error ocurred
400  */
401 {
402   PFILE_REQUEST_SELECT Request;
403   FILE_REPLY_SELECT Reply;
404   IO_STATUS_BLOCK Iosb;
405   NTSTATUS Status;
406   DWORD Size;
407   DWORD ReadSize;
408   DWORD WriteSize;
409   DWORD ExceptSize;
410   PVOID Current;
411
412   AFD_DbgPrint(MAX_TRACE, ("readfds (0x%X)  writefds (0x%X)  exceptfds (0x%X).\n",
413         readfds, writefds, exceptfds));
414 #if 0
415   /* FIXME: For now, all reads are timed out immediately */
416   if (readfds != NULL) {
417     AFD_DbgPrint(MID_TRACE, ("Timing out read query.\n"));
418     *lpErrno = WSAETIMEDOUT;
419     return SOCKET_ERROR;
420   }
421
422   /* FIXME: For now, always allow write */
423   if (writefds != NULL) {
424     AFD_DbgPrint(MID_TRACE, ("Setting one socket writeable.\n"));
425     *lpErrno = NO_ERROR;
426     return 1;
427   }
428 #endif
429
430   ReadSize = 0;
431   if ((readfds != NULL) && (readfds->fd_count > 0)) {
432     ReadSize = (readfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
433   }
434
435   WriteSize = 0;
436   if ((writefds != NULL) && (writefds->fd_count > 0)) {
437     WriteSize = (writefds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
438   }
439
440   ExceptSize = 0;
441   if ((exceptfds != NULL) && (exceptfds->fd_count > 0)) {
442     ExceptSize = (exceptfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
443   }
444
445   Size = ReadSize + WriteSize + ExceptSize;
446
447   Request = (PFILE_REQUEST_SELECT)HeapAlloc(
448     GlobalHeap, 0, sizeof(FILE_REQUEST_SELECT) + Size);
449   if (!Request) {
450     *lpErrno = WSAENOBUFS;
451     return SOCKET_ERROR;
452   }
453
454   /* Put FD SETs after request structure */
455   Current = (Request + 1);
456
457   if (ReadSize > 0) {
458     Request->ReadFDSet = (LPFD_SET)Current;
459     Current += ReadSize;
460     RtlCopyMemory(Request->ReadFDSet, readfds, ReadSize);
461   } else {
462     Request->ReadFDSet = NULL;
463   }
464
465   if (WriteSize > 0) {
466     Request->WriteFDSet = (LPFD_SET)Current;
467     Current += WriteSize;
468     RtlCopyMemory(Request->WriteFDSet, writefds, WriteSize);
469   } else {
470     Request->WriteFDSet = NULL;
471   }
472
473   if (ExceptSize > 0) {
474     Request->ExceptFDSet = (LPFD_SET)Current;
475     RtlCopyMemory(Request->ExceptFDSet, exceptfds, ExceptSize);
476   } else {
477     Request->ExceptFDSet = NULL;
478   }
479
480   AFD_DbgPrint(MAX_TRACE, ("R1 (0x%X)  W1 (0x%X).\n", Request->ReadFDSet, Request->WriteFDSet));
481
482   Status = NtDeviceIoControlFile(
483     CommandChannel,
484     NULL,
485                 NULL,
486                 NULL,   
487                 &Iosb,
488                 IOCTL_AFD_SELECT,
489                 Request,
490                 sizeof(FILE_REQUEST_SELECT) + Size,
491                 &Reply,
492                 sizeof(FILE_REPLY_SELECT));
493
494   HeapFree(GlobalHeap, 0, Request);
495
496         if (Status == STATUS_PENDING) {
497     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
498     /* FIXME: Wait only for blocking sockets */
499                 Status = NtWaitForSingleObject(CommandChannel, FALSE, NULL);
500   }
501
502   if (!NT_SUCCESS(Status)) {
503     AFD_DbgPrint(MAX_TRACE, ("Status (0x%X).\n", Status));
504                 *lpErrno = WSAENOBUFS;
505     return SOCKET_ERROR;
506         }
507
508   AFD_DbgPrint(MAX_TRACE, ("Select successful. Status (0x%X)  Count (0x%X).\n",
509     Reply.Status, Reply.SocketCount));
510
511   *lpErrno = Reply.Status;
512
513   return Reply.SocketCount;
514 }
515
516 SOCKET
517 WSPAPI
518 WSPAccept(
519   IN      SOCKET s,
520   OUT     LPSOCKADDR addr,
521   IN OUT  LPINT addrlen,
522   IN      LPCONDITIONPROC lpfnCondition,
523   IN      DWORD dwCallbackData,
524   OUT     LPINT lpErrno)
525 {
526   FILE_REQUEST_ACCEPT Request;
527   FILE_REPLY_ACCEPT Reply;
528   IO_STATUS_BLOCK Iosb;
529   NTSTATUS Status;
530
531   AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
532
533   Request.addr = addr;
534   Request.addrlen = *addrlen;
535   Request.lpfnCondition = lpfnCondition;
536   Request.dwCallbackData = dwCallbackData;
537
538   Status = NtDeviceIoControlFile(
539     (HANDLE)s,
540     NULL,
541                 NULL,
542                 NULL,
543                 &Iosb,
544                 IOCTL_AFD_ACCEPT,
545                 &Request,
546                 sizeof(FILE_REQUEST_ACCEPT),
547                 &Reply,
548                 sizeof(FILE_REPLY_ACCEPT));
549   if (Status == STATUS_PENDING) {
550     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
551     /* FIXME: Wait only for blocking sockets */
552                 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
553   }
554
555   if (!NT_SUCCESS(Status)) {
556           *lpErrno = Reply.Status;
557     return INVALID_SOCKET;
558         }
559
560   *addrlen = Reply.addrlen;
561
562   return Reply.Socket;
563 }
564
565
566 INT
567 WSPAPI
568 WSPConnect(
569   IN  SOCKET s,
570   IN  CONST LPSOCKADDR name,
571   IN  INT namelen,
572   IN  LPWSABUF lpCallerData,
573   OUT LPWSABUF lpCalleeData,
574   IN  LPQOS lpSQOS,
575   IN  LPQOS lpGQOS,
576   OUT LPINT lpErrno)
577 {
578   FILE_REQUEST_CONNECT Request;
579   FILE_REPLY_CONNECT Reply;
580   IO_STATUS_BLOCK Iosb;
581   NTSTATUS Status;
582
583   AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
584
585   Request.name = name;
586   Request.namelen = namelen;
587   Request.lpCallerData = lpCallerData;
588   Request.lpCalleeData = lpCalleeData;
589   Request.lpSQOS = lpSQOS;
590   Request.lpGQOS = lpGQOS;
591
592   Status = NtDeviceIoControlFile(
593     (HANDLE)s,
594     NULL,
595                 NULL,
596                 NULL,
597                 &Iosb,
598                 IOCTL_AFD_CONNECT,
599                 &Request,
600                 sizeof(FILE_REQUEST_CONNECT),
601                 &Reply,
602                 sizeof(FILE_REPLY_CONNECT));
603   if (Status == STATUS_PENDING) {
604     AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
605     /* FIXME: Wait only for blocking sockets */
606                 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
607   }
608
609   if (!NT_SUCCESS(Status)) {
610           *lpErrno = Reply.Status;
611     return INVALID_SOCKET;
612         }
613
614   return 0;
615 }
616
617
618 NTSTATUS OpenCommandChannel(
619   VOID)
620 /*
621  * FUNCTION: Opens a command channel to afd.sys
622  * ARGUMENTS:
623  *     None
624  * RETURNS:
625  *     Status of operation
626  */
627 {
628   OBJECT_ATTRIBUTES ObjectAttributes;
629   PAFD_SOCKET_INFORMATION SocketInfo;
630   PFILE_FULL_EA_INFORMATION EaInfo;
631   UNICODE_STRING DeviceName;
632   IO_STATUS_BLOCK Iosb;
633   HANDLE FileHandle;
634   NTSTATUS Status;
635   ULONG EaLength;
636   ULONG EaShort;
637
638   AFD_DbgPrint(MAX_TRACE, ("Called\n"));
639
640   EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
641     AFD_SOCKET_LENGTH +
642     sizeof(AFD_SOCKET_INFORMATION);
643
644   EaLength = EaShort;
645
646   EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
647   if (!EaInfo) {
648     return STATUS_INSUFFICIENT_RESOURCES;
649   }
650
651   RtlZeroMemory(EaInfo, EaLength);
652   EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
653   RtlCopyMemory(EaInfo->EaName,
654     AfdSocket,
655     AFD_SOCKET_LENGTH);
656   EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
657
658   SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
659   SocketInfo->CommandChannel = TRUE;
660
661   RtlInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
662         InitializeObjectAttributes(
663     &ObjectAttributes,
664     &DeviceName,
665     0,
666     NULL,
667     NULL);
668
669   Status = NtCreateFile(
670     &FileHandle,
671     FILE_GENERIC_READ | FILE_GENERIC_WRITE,
672     &ObjectAttributes,
673     &Iosb,
674     NULL,
675                 0,
676                 0,
677                 FILE_OPEN,
678                 FILE_SYNCHRONOUS_IO_ALERT,
679     EaInfo,
680     EaLength);
681
682   if (!NT_SUCCESS(Status)) {
683     AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
684       (UINT)Status));
685     return Status;
686   }
687
688   CommandChannel = FileHandle;
689
690   return STATUS_SUCCESS;
691 }
692
693
694 NTSTATUS CloseCommandChannel(
695   VOID)
696 /*
697  * FUNCTION: Closes command channel to afd.sys
698  * ARGUMENTS:
699  *     None
700  * RETURNS:
701  *     Status of operation
702  */
703 {
704   AFD_DbgPrint(MAX_TRACE, ("Called.\n"));
705
706   return NtClose(CommandChannel);
707 }
708
709
710 INT
711 WSPAPI
712 WSPStartup(
713   IN  WORD wVersionRequested,
714   OUT LPWSPDATA lpWSPData,
715   IN  LPWSAPROTOCOL_INFOW lpProtocolInfo,
716   IN  WSPUPCALLTABLE UpcallTable,
717   OUT LPWSPPROC_TABLE lpProcTable)
718 /*
719  * FUNCTION: Initialize service provider for a client
720  * ARGUMENTS:
721  *     wVersionRequested = Highest WinSock SPI version that the caller can use
722  *     lpWSPData         = Address of WSPDATA structure to initialize
723  *     lpProtocolInfo    = Pointer to structure that defines the desired protocol
724  *     UpcallTable       = Pointer to upcall table of the WinSock DLL
725  *     lpProcTable       = Address of procedure table to initialize
726  * RETURNS:
727  *     Status of operation
728  */
729 {
730   HMODULE hWS2_32;
731   INT Status;
732
733   AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
734
735   EnterCriticalSection(&InitCriticalSection);
736
737   Upcalls = UpcallTable;
738
739   if (StartupCount == 0) {
740     /* First time called */
741
742     Status = OpenCommandChannel();
743     if (NT_SUCCESS(Status)) {
744       hWS2_32 = GetModuleHandle(L"ws2_32.dll");
745       if (hWS2_32 != NULL) {
746         lpWPUCompleteOverlappedRequest = (LPWPUCOMPLETEOVERLAPPEDREQUEST)
747           GetProcAddress(hWS2_32, "WPUCompleteOverlappedRequest");
748         if (lpWPUCompleteOverlappedRequest != NULL) {
749           Status = NO_ERROR;
750           StartupCount++;
751         }
752       } else {
753         AFD_DbgPrint(MIN_TRACE, ("GetModuleHandle() failed for ws2_32.dll\n"));
754       }
755     } else {
756       AFD_DbgPrint(MIN_TRACE, ("Cannot open afd.sys\n"));
757     }
758   } else {
759     Status = NO_ERROR;
760     StartupCount++;
761   }
762
763   LeaveCriticalSection(&InitCriticalSection);
764
765   if (Status == NO_ERROR) {
766     lpProcTable->lpWSPAccept = WSPAccept;
767     lpProcTable->lpWSPAddressToString = WSPAddressToString;
768     lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
769     lpProcTable->lpWSPBind = WSPBind;
770     lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
771     lpProcTable->lpWSPCleanup = WSPCleanup;
772     lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
773     lpProcTable->lpWSPConnect = WSPConnect;
774     lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
775     lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
776     lpProcTable->lpWSPEventSelect = WSPEventSelect;
777     lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
778     lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
779     lpProcTable->lpWSPGetSockName = WSPGetSockName;
780     lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
781     lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
782     lpProcTable->lpWSPIoctl = WSPIoctl;
783     lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
784     lpProcTable->lpWSPListen = WSPListen;
785     lpProcTable->lpWSPRecv = WSPRecv;
786     lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
787     lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
788     lpProcTable->lpWSPSelect = WSPSelect;
789     lpProcTable->lpWSPSend = WSPSend;
790     lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
791     lpProcTable->lpWSPSendTo = WSPSendTo;
792     lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
793     lpProcTable->lpWSPShutdown = WSPShutdown;
794     lpProcTable->lpWSPSocket = WSPSocket;
795     lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
796
797     lpWSPData->wVersion     = MAKEWORD(2, 2);
798     lpWSPData->wHighVersion = MAKEWORD(2, 2);
799   }
800
801   AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
802
803   return Status;
804 }
805
806
807 INT
808 WSPAPI
809 WSPCleanup(
810   OUT LPINT lpErrno)
811 /*
812  * FUNCTION: Cleans up service provider for a client
813  * ARGUMENTS:
814  *     lpErrno = Address of buffer for error information
815  * RETURNS:
816  *     0 if successful, or SOCKET_ERROR if not
817  */
818 {
819   AFD_DbgPrint(MAX_TRACE, ("\n"));
820
821   EnterCriticalSection(&InitCriticalSection);
822
823   if (StartupCount > 0) {
824     StartupCount--;
825
826     if (StartupCount == 0) {
827       AFD_DbgPrint(MAX_TRACE, ("Cleaning up msafd.dll.\n"));
828
829       CloseCommandChannel();
830     }
831   }
832
833   LeaveCriticalSection(&InitCriticalSection);
834
835   AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
836
837   *lpErrno = NO_ERROR;
838
839   return 0;
840 }
841
842
843 BOOL
844 STDCALL
845 DllMain(HANDLE hInstDll,
846         ULONG dwReason,
847         PVOID Reserved)
848 {
849     AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll\n"));
850
851     switch (dwReason) {
852     case DLL_PROCESS_ATTACH:
853         /* Don't need thread attach notifications
854            so disable them to improve performance */
855         DisableThreadLibraryCalls(hInstDll);
856
857         InitializeCriticalSection(&InitCriticalSection);
858
859         GlobalHeap = GetProcessHeap();
860
861         CreateHelperDLLDatabase();
862         break;
863
864     case DLL_THREAD_ATTACH:
865         break;
866
867     case DLL_THREAD_DETACH:
868         break;
869
870     case DLL_PROCESS_DETACH:
871
872         DestroyHelperDLLDatabase();
873
874         DeleteCriticalSection(&InitCriticalSection);
875
876         break;
877     }
878
879     AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
880
881     return TRUE;
882 }
883
884 /* EOF */