2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
8 * CSH 01/09-2000 Created
16 /* See debug.h for debug/trace constants */
17 DWORD DebugTraceLevel = MIN_TRACE;
18 //DWORD DebugTraceLevel = DEBUG_ULTRA;
22 /* To make the linker happy */
23 VOID STDCALL KeBugCheck (ULONG BugCheckCode) {}
27 WSPUPCALLTABLE Upcalls;
28 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
29 CRITICAL_SECTION InitCriticalSection;
30 DWORD StartupCount = 0;
31 HANDLE CommandChannel;
40 DWORD NotificationEvents,
41 PUNICODE_STRING TdiDeviceName)
43 * FUNCTION: Opens a socket
45 * Socket = Address of buffer to place socket descriptor
46 * AddressFamily = Address family
47 * SocketType = Type of socket
48 * Protocol = Protocol type
49 * HelperContext = Pointer to context information for helper DLL
50 * NotificationEvents = Events for which helper DLL is to be notified
51 * TdiDeviceName = Pointer to name of TDI device to use
56 OBJECT_ATTRIBUTES ObjectAttributes;
57 PAFD_SOCKET_INFORMATION SocketInfo;
58 PFILE_FULL_EA_INFORMATION EaInfo;
59 UNICODE_STRING DeviceName;
66 AFD_DbgPrint(MAX_TRACE, ("Socket (0x%X) TdiDeviceName (%wZ)\n",
67 Socket, TdiDeviceName));
69 AFD_DbgPrint(MAX_TRACE, ("Socket2 (0x%X) TdiDeviceName (%S)\n",
70 Socket, TdiDeviceName->Buffer));
72 EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
74 sizeof(AFD_SOCKET_INFORMATION);
76 EaLength = EaShort + TdiDeviceName->Length + sizeof(WCHAR);
78 EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
80 return STATUS_INSUFFICIENT_RESOURCES;
83 RtlZeroMemory(EaInfo, EaLength);
84 EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
85 RtlCopyMemory(EaInfo->EaName,
88 EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
90 SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
91 SocketInfo->CommandChannel = FALSE;
92 SocketInfo->AddressFamily = AddressFamily;
93 SocketInfo->SocketType = SocketType;
94 SocketInfo->Protocol = Protocol;
95 SocketInfo->HelperContext = HelperContext;
96 SocketInfo->NotificationEvents = NotificationEvents;
97 /* Zeroed above so initialized to a wildcard address if a raw socket */
98 SocketInfo->Name.sa_family = AddressFamily;
100 /* Store TDI device name last in buffer */
101 SocketInfo->TdiDeviceName.Buffer = (PWCHAR)(EaInfo + EaShort);
102 SocketInfo->TdiDeviceName.MaximumLength = TdiDeviceName->Length + sizeof(WCHAR);
103 RtlCopyUnicodeString(&SocketInfo->TdiDeviceName, TdiDeviceName);
105 AFD_DbgPrint(MAX_TRACE, ("EaInfo at (0x%X) EaLength is (%d).\n", (UINT)EaInfo, (INT)EaLength));
107 RtlInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
108 InitializeObjectAttributes(
115 Status = NtCreateFile(
117 FILE_GENERIC_READ | FILE_GENERIC_WRITE,
124 FILE_SYNCHRONOUS_IO_ALERT,
128 HeapFree(GlobalHeap, 0, EaInfo);
130 if (!NT_SUCCESS(Status)) {
131 AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
133 return STATUS_INSUFFICIENT_RESOURCES;
136 *Socket = (SOCKET)FileHandle;
138 return STATUS_SUCCESS;
148 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
153 * FUNCTION: Creates a new socket
155 * af = Address family
157 * protocol = Protocol type
158 * lpProtocolInfo = Pointer to protocol information
160 * dwFlags = Socket flags
161 * lpErrno = Address of buffer for error information
163 * Created socket, or INVALID_SOCKET if it could not be created
166 WSAPROTOCOL_INFOW ProtocolInfo;
167 UNICODE_STRING TdiDeviceName;
168 DWORD NotificationEvents;
169 PWSHELPER_DLL HelperDLL;
179 AFD_DbgPrint(MAX_TRACE, ("af (%d) type (%d) protocol (%d).\n",
180 af, type, protocol));
182 if (!lpProtocolInfo) {
183 lpProtocolInfo = &ProtocolInfo;
184 ZeroMemory(&ProtocolInfo, sizeof(WSAPROTOCOL_INFOW));
186 ProtocolInfo.iAddressFamily = af;
187 ProtocolInfo.iSocketType = type;
188 ProtocolInfo.iProtocol = protocol;
191 HelperDLL = LocateHelperDLL(lpProtocolInfo);
193 *lpErrno = WSAEAFNOSUPPORT;
194 return INVALID_SOCKET;
197 AddressFamily = lpProtocolInfo->iAddressFamily;
198 SocketType = lpProtocolInfo->iSocketType;
199 Protocol = lpProtocolInfo->iProtocol;
201 Status = HelperDLL->EntryTable.lpWSHOpenSocket2(
209 &NotificationEvents);
210 if (Status != NO_ERROR) {
211 AFD_DbgPrint(MAX_TRACE, ("WinSock Helper DLL failed (0x%X).\n", Status));
213 return INVALID_SOCKET;
216 NtStatus = OpenSocket(&Socket,
224 RtlFreeUnicodeString(&TdiDeviceName);
225 if (!NT_SUCCESS(NtStatus)) {
226 *lpErrno = RtlNtStatusToDosError(Status);
227 return INVALID_SOCKET;
230 /* FIXME: Assumes catalog entry id to be 1 */
231 Socket2 = Upcalls.lpWPUModifyIFSHandle(1, Socket, lpErrno);
233 if (Socket2 == INVALID_SOCKET) {
235 AFD_DbgPrint(MIN_TRACE, ("FIXME: Cleanup.\n"));
236 return INVALID_SOCKET;
241 AFD_DbgPrint(MID_TRACE, ("Returning socket descriptor (0x%X).\n", Socket2));
253 * FUNCTION: Closes an open socket
255 * s = Socket descriptor
256 * lpErrno = Address of buffer for error information
258 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
263 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
265 Status = NtClose((HANDLE)s);
267 if (NT_SUCCESS(Status)) {
272 *lpErrno = WSAENOTSOCK;
281 IN CONST LPSOCKADDR name,
285 * FUNCTION: Associates a local address with a socket
287 * s = Socket descriptor
288 * name = Pointer to local address
289 * namelen = Length of name
290 * lpErrno = Address of buffer for error information
292 * 0, or SOCKET_ERROR if the socket could not be bound
295 FILE_REQUEST_BIND Request;
296 FILE_REPLY_BIND Reply;
297 IO_STATUS_BLOCK Iosb;
300 AFD_DbgPrint(MAX_TRACE, ("s (0x%X) name (0x%X) namelen (%d).\n", s, name, namelen));
302 RtlCopyMemory(&Request.Name, name, sizeof(SOCKADDR));
304 Status = NtDeviceIoControlFile(
312 sizeof(FILE_REQUEST_BIND),
314 sizeof(FILE_REPLY_BIND));
315 if (Status == STATUS_PENDING) {
316 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
317 /* FIXME: Wait only for blocking sockets */
318 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
321 if (!NT_SUCCESS(Status)) {
322 *lpErrno = Reply.Status;
336 * FUNCTION: Listens for incoming connections
338 * s = Socket descriptor
339 * backlog = Maximum number of pending connection requests
340 * lpErrno = Address of buffer for error information
342 * 0, or SOCKET_ERROR if the socket could not be put into the listening state
345 FILE_REQUEST_LISTEN Request;
346 FILE_REPLY_LISTEN Reply;
347 IO_STATUS_BLOCK Iosb;
350 AFD_DbgPrint(MAX_TRACE, ("s (0x%X) backlog (%d).\n", s, backlog));
352 Request.Backlog = backlog;
354 Status = NtDeviceIoControlFile(
362 sizeof(FILE_REQUEST_LISTEN),
364 sizeof(FILE_REPLY_LISTEN));
365 if (Status == STATUS_PENDING) {
366 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
367 /* FIXME: Wait only for blocking sockets */
368 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
371 if (!NT_SUCCESS(Status)) {
372 *lpErrno = Reply.Status;
384 IN OUT LPFD_SET readfds,
385 IN OUT LPFD_SET writefds,
386 IN OUT LPFD_SET exceptfds,
387 IN CONST LPTIMEVAL timeout,
390 * FUNCTION: Returns status of one or more sockets
392 * nfds = Always ignored
393 * readfds = Pointer to socket set to be checked for readability (optional)
394 * writefds = Pointer to socket set to be checked for writability (optional)
395 * exceptfds = Pointer to socket set to be checked for errors (optional)
396 * timeout = Pointer to a TIMEVAL structure indicating maximum wait time
397 * (NULL means wait forever)
398 * lpErrno = Address of buffer for error information
400 * Number of ready socket descriptors, or SOCKET_ERROR if an error occurred
403 PFILE_REQUEST_SELECT Request;
404 FILE_REPLY_SELECT Reply;
405 IO_STATUS_BLOCK Iosb;
413 AFD_DbgPrint(MAX_TRACE, ("readfds (0x%X) writefds (0x%X) exceptfds (0x%X).\n",
414 readfds, writefds, exceptfds));
416 /* FIXME: For now, all reads are timed out immediately */
417 if (readfds != NULL) {
418 AFD_DbgPrint(MID_TRACE, ("Timing out read query.\n"));
419 *lpErrno = WSAETIMEDOUT;
423 /* FIXME: For now, always allow write */
424 if (writefds != NULL) {
425 AFD_DbgPrint(MID_TRACE, ("Setting one socket writeable.\n"));
432 if ((readfds != NULL) && (readfds->fd_count > 0)) {
433 ReadSize = (readfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
437 if ((writefds != NULL) && (writefds->fd_count > 0)) {
438 WriteSize = (writefds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
442 if ((exceptfds != NULL) && (exceptfds->fd_count > 0)) {
443 ExceptSize = (exceptfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
446 Size = ReadSize + WriteSize + ExceptSize;
448 Request = (PFILE_REQUEST_SELECT)HeapAlloc(
449 GlobalHeap, 0, sizeof(FILE_REQUEST_SELECT) + Size);
451 *lpErrno = WSAENOBUFS;
455 /* Put FD SETs after request structure */
456 Current = (Request + 1);
459 Request->ReadFDSet = (LPFD_SET)Current;
461 RtlCopyMemory(Request->ReadFDSet, readfds, ReadSize);
463 Request->ReadFDSet = NULL;
467 Request->WriteFDSet = (LPFD_SET)Current;
468 Current += WriteSize;
469 RtlCopyMemory(Request->WriteFDSet, writefds, WriteSize);
471 Request->WriteFDSet = NULL;
474 if (ExceptSize > 0) {
475 Request->ExceptFDSet = (LPFD_SET)Current;
476 RtlCopyMemory(Request->ExceptFDSet, exceptfds, ExceptSize);
478 Request->ExceptFDSet = NULL;
481 AFD_DbgPrint(MAX_TRACE, ("R1 (0x%X) W1 (0x%X).\n", Request->ReadFDSet, Request->WriteFDSet));
483 Status = NtDeviceIoControlFile(
491 sizeof(FILE_REQUEST_SELECT) + Size,
493 sizeof(FILE_REPLY_SELECT));
495 HeapFree(GlobalHeap, 0, Request);
497 if (Status == STATUS_PENDING) {
498 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
499 /* FIXME: Wait only for blocking sockets */
500 Status = NtWaitForSingleObject(CommandChannel, FALSE, NULL);
503 if (!NT_SUCCESS(Status)) {
504 AFD_DbgPrint(MAX_TRACE, ("Status (0x%X).\n", Status));
505 *lpErrno = WSAENOBUFS;
509 AFD_DbgPrint(MAX_TRACE, ("Select successful. Status (0x%X) Count (0x%X).\n",
510 Reply.Status, Reply.SocketCount));
512 *lpErrno = Reply.Status;
514 return Reply.SocketCount;
522 IN OUT LPINT addrlen,
523 IN LPCONDITIONPROC lpfnCondition,
524 IN DWORD dwCallbackData,
527 FILE_REQUEST_ACCEPT Request;
528 FILE_REPLY_ACCEPT Reply;
529 IO_STATUS_BLOCK Iosb;
532 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
535 Request.addrlen = *addrlen;
536 Request.lpfnCondition = lpfnCondition;
537 Request.dwCallbackData = dwCallbackData;
539 Status = NtDeviceIoControlFile(
547 sizeof(FILE_REQUEST_ACCEPT),
549 sizeof(FILE_REPLY_ACCEPT));
550 if (Status == STATUS_PENDING) {
551 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
552 /* FIXME: Wait only for blocking sockets */
553 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
556 if (!NT_SUCCESS(Status)) {
557 *lpErrno = Reply.Status;
558 return INVALID_SOCKET;
561 *addrlen = Reply.addrlen;
571 IN CONST LPSOCKADDR name,
573 IN LPWSABUF lpCallerData,
574 OUT LPWSABUF lpCalleeData,
579 FILE_REQUEST_CONNECT Request;
580 FILE_REPLY_CONNECT Reply;
581 IO_STATUS_BLOCK Iosb;
584 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
587 Request.namelen = namelen;
588 Request.lpCallerData = lpCallerData;
589 Request.lpCalleeData = lpCalleeData;
590 Request.lpSQOS = lpSQOS;
591 Request.lpGQOS = lpGQOS;
593 Status = NtDeviceIoControlFile(
601 sizeof(FILE_REQUEST_CONNECT),
603 sizeof(FILE_REPLY_CONNECT));
604 if (Status == STATUS_PENDING) {
605 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
606 /* FIXME: Wait only for blocking sockets */
607 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
610 if (!NT_SUCCESS(Status)) {
611 *lpErrno = Reply.Status;
612 return INVALID_SOCKET;
619 NTSTATUS OpenCommandChannel(
622 * FUNCTION: Opens a command channel to afd.sys
626 * Status of operation
629 OBJECT_ATTRIBUTES ObjectAttributes;
630 PAFD_SOCKET_INFORMATION SocketInfo;
631 PFILE_FULL_EA_INFORMATION EaInfo;
632 UNICODE_STRING DeviceName;
633 IO_STATUS_BLOCK Iosb;
639 AFD_DbgPrint(MAX_TRACE, ("Called\n"));
641 EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
643 sizeof(AFD_SOCKET_INFORMATION);
647 EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
649 return STATUS_INSUFFICIENT_RESOURCES;
652 RtlZeroMemory(EaInfo, EaLength);
653 EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
654 RtlCopyMemory(EaInfo->EaName,
657 EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
659 SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
660 SocketInfo->CommandChannel = TRUE;
662 RtlInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
663 InitializeObjectAttributes(
670 Status = NtCreateFile(
672 FILE_GENERIC_READ | FILE_GENERIC_WRITE,
679 FILE_SYNCHRONOUS_IO_ALERT,
683 if (!NT_SUCCESS(Status)) {
684 AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
689 CommandChannel = FileHandle;
691 return STATUS_SUCCESS;
695 NTSTATUS CloseCommandChannel(
698 * FUNCTION: Closes command channel to afd.sys
702 * Status of operation
705 AFD_DbgPrint(MAX_TRACE, ("Called.\n"));
707 return NtClose(CommandChannel);
714 IN WORD wVersionRequested,
715 OUT LPWSPDATA lpWSPData,
716 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
717 IN WSPUPCALLTABLE UpcallTable,
718 OUT LPWSPPROC_TABLE lpProcTable)
720 * FUNCTION: Initialize service provider for a client
722 * wVersionRequested = Highest WinSock SPI version that the caller can use
723 * lpWSPData = Address of WSPDATA structure to initialize
724 * lpProtocolInfo = Pointer to structure that defines the desired protocol
725 * UpcallTable = Pointer to upcall table of the WinSock DLL
726 * lpProcTable = Address of procedure table to initialize
728 * Status of operation
734 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
736 EnterCriticalSection(&InitCriticalSection);
738 Upcalls = UpcallTable;
740 if (StartupCount == 0) {
741 /* First time called */
743 Status = OpenCommandChannel();
744 if (NT_SUCCESS(Status)) {
745 hWS2_32 = GetModuleHandle(L"ws2_32.dll");
746 if (hWS2_32 != NULL) {
747 lpWPUCompleteOverlappedRequest = (LPWPUCOMPLETEOVERLAPPEDREQUEST)
748 GetProcAddress(hWS2_32, "WPUCompleteOverlappedRequest");
749 if (lpWPUCompleteOverlappedRequest != NULL) {
754 AFD_DbgPrint(MIN_TRACE, ("GetModuleHandle() failed for ws2_32.dll\n"));
757 AFD_DbgPrint(MIN_TRACE, ("Cannot open afd.sys\n"));
764 LeaveCriticalSection(&InitCriticalSection);
766 if (Status == NO_ERROR) {
767 lpProcTable->lpWSPAccept = WSPAccept;
768 lpProcTable->lpWSPAddressToString = WSPAddressToString;
769 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
770 lpProcTable->lpWSPBind = WSPBind;
771 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
772 lpProcTable->lpWSPCleanup = WSPCleanup;
773 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
774 lpProcTable->lpWSPConnect = WSPConnect;
775 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
776 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
777 lpProcTable->lpWSPEventSelect = WSPEventSelect;
778 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
779 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
780 lpProcTable->lpWSPGetSockName = WSPGetSockName;
781 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
782 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
783 lpProcTable->lpWSPIoctl = WSPIoctl;
784 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
785 lpProcTable->lpWSPListen = WSPListen;
786 lpProcTable->lpWSPRecv = WSPRecv;
787 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
788 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
789 lpProcTable->lpWSPSelect = WSPSelect;
790 lpProcTable->lpWSPSend = WSPSend;
791 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
792 lpProcTable->lpWSPSendTo = WSPSendTo;
793 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
794 lpProcTable->lpWSPShutdown = WSPShutdown;
795 lpProcTable->lpWSPSocket = WSPSocket;
796 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
798 lpWSPData->wVersion = MAKEWORD(2, 2);
799 lpWSPData->wHighVersion = MAKEWORD(2, 2);
802 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
813 * FUNCTION: Cleans up service provider for a client
815 * lpErrno = Address of buffer for error information
817 * 0 if successful, or SOCKET_ERROR if not
820 AFD_DbgPrint(MAX_TRACE, ("\n"));
822 EnterCriticalSection(&InitCriticalSection);
824 if (StartupCount > 0) {
827 if (StartupCount == 0) {
828 AFD_DbgPrint(MAX_TRACE, ("Cleaning up msafd.dll.\n"));
830 CloseCommandChannel();
834 LeaveCriticalSection(&InitCriticalSection);
836 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
846 DllMain(HANDLE hInstDll,
850 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll\n"));
853 case DLL_PROCESS_ATTACH:
854 /* Don't need thread attach notifications
855 so disable them to improve performance */
856 DisableThreadLibraryCalls(hInstDll);
858 InitializeCriticalSection(&InitCriticalSection);
860 GlobalHeap = GetProcessHeap();
862 CreateHelperDLLDatabase();
865 case DLL_THREAD_ATTACH:
868 case DLL_THREAD_DETACH:
871 case DLL_PROCESS_DETACH:
873 DestroyHelperDLLDatabase();
875 DeleteCriticalSection(&InitCriticalSection);
880 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));