2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
8 * CSH 01/09-2000 Created
15 /* See debug.h for debug/trace constants */
16 DWORD DebugTraceLevel = MIN_TRACE;
17 //DWORD DebugTraceLevel = DEBUG_ULTRA;
// Swap in the commented-out DEBUG_ULTRA line above to get maximum tracing.
21 /* To make the linker happy */
// Empty stub that only satisfies a KeBugCheck reference at link time; calling it does nothing.
22 VOID STDCALL KeBugCheck (ULONG BugCheckCode) {}
// Copy of the upcall table passed to WSPStartup; overwritten on every WSPStartup call.
26 WSPUPCALLTABLE Upcalls;
// Resolved from ws2_32.dll with GetProcAddress during the first WSPStartup call.
27 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
// Serializes WSPStartup/WSPCleanup (StartupCount and the command channel);
// initialized and deleted in DllMain.
28 CRITICAL_SECTION InitCriticalSection;
// WSPStartup/WSPCleanup nesting count. The command channel is opened when it is 0
// in WSPStartup and closed when it reaches 0 in WSPCleanup; the increment/decrement
// lines themselves are not visible in this excerpt.
29 DWORD StartupCount = 0;
// Handle to \Device\Afd opened by OpenCommandChannel; waited on by WSPSelect.
30 HANDLE CommandChannel;
39 DWORD NotificationEvents,
40 PUNICODE_STRING TdiDeviceName)
42 * FUNCTION: Opens a socket
44 * Socket = Address of buffer to place socket descriptor
45 * AddressFamily = Address family
46 * SocketType = Type of socket
47 * Protocol = Protocol type
48 * HelperContext = Pointer to context information for helper DLL
49 * NotificationEvents = Events for which helper DLL is to be notified
50 * TdiDeviceName = Pointer to name of TDI device to use
// RETURNS: STATUS_SUCCESS with *Socket set to the AFD file handle, or
// STATUS_INSUFFICIENT_RESOURCES on allocation or NtCreateFile failure.
55 OBJECT_ATTRIBUTES ObjectAttributes;
56 PAFD_SOCKET_INFORMATION SocketInfo;
57 PFILE_FULL_EA_INFORMATION EaInfo;
58 UNICODE_STRING DeviceName;
65 AFD_DbgPrint(MAX_TRACE, ("Socket (0x%X) TdiDeviceName (%wZ)\n",
66 Socket, TdiDeviceName));
// NOTE(review): %S expects a NUL-terminated string, which a UNICODE_STRING Buffer is
// not guaranteed to be; the %wZ print above is the safe form.
68 AFD_DbgPrint(MAX_TRACE, ("Socket2 (0x%X) TdiDeviceName (%S)\n",
69 Socket, TdiDeviceName->Buffer));
// EaShort covers the EA header, EA name and the AFD_SOCKET_INFORMATION value;
// the TDI device name (plus one WCHAR terminator) is appended after it.
71 EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
73 sizeof(AFD_SOCKET_INFORMATION);
75 EaLength = EaShort + TdiDeviceName->Length + sizeof(WCHAR);
77 EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
79 return STATUS_INSUFFICIENT_RESOURCES;
82 RtlZeroMemory(EaInfo, EaLength);
83 EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
84 RtlCopyMemory(EaInfo->EaName,
87 EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
// The EA value (AFD_SOCKET_INFORMATION) starts right after the EA name.
89 SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
90 SocketInfo->CommandChannel = FALSE;
91 SocketInfo->AddressFamily = AddressFamily;
92 SocketInfo->SocketType = SocketType;
93 SocketInfo->Protocol = Protocol;
94 SocketInfo->HelperContext = HelperContext;
95 SocketInfo->NotificationEvents = NotificationEvents;
96 /* Zeroed above so initialized to a wildcard address if a raw socket */
97 SocketInfo->Name.sa_family = AddressFamily;
99 /* Store TDI device name last in buffer */
// BUG(review): EaInfo is a PFILE_FULL_EA_INFORMATION, so `EaInfo + EaShort` advances
// EaShort * sizeof(FILE_FULL_EA_INFORMATION) bytes rather than EaShort bytes. Buffer then
// points beyond the EaLength-byte allocation, and RtlCopyUnicodeString below writes out of
// bounds. Byte arithmetic was presumably intended: (PWCHAR)((ULONG_PTR)EaInfo + EaShort).
100 SocketInfo->TdiDeviceName.Buffer = (PWCHAR)(EaInfo + EaShort);
101 SocketInfo->TdiDeviceName.MaximumLength = TdiDeviceName->Length + sizeof(WCHAR);
102 RtlCopyUnicodeString(&SocketInfo->TdiDeviceName, TdiDeviceName);
// NOTE(review): the (UINT) cast truncates the pointer on 64-bit builds.
104 AFD_DbgPrint(MAX_TRACE, ("EaInfo at (0x%X) EaLength is (%d).\n", (UINT)EaInfo, (INT)EaLength));
106 RtlInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
107 InitializeObjectAttributes(
114 Status = NtCreateFile(
116 FILE_GENERIC_READ | FILE_GENERIC_WRITE,
123 FILE_SYNCHRONOUS_IO_ALERT,
// The EA buffer is freed right after NtCreateFile, on both success and failure.
127 HeapFree(GlobalHeap, 0, EaInfo);
129 if (!NT_SUCCESS(Status)) {
130 AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
// NOTE(review): every NtCreateFile failure is reported as STATUS_INSUFFICIENT_RESOURCES;
// the real Status (only logged above) is discarded. Returning Status would preserve it.
132 return STATUS_INSUFFICIENT_RESOURCES;
// The NT file handle itself is handed back as the SOCKET value.
135 *Socket = (SOCKET)FileHandle;
137 return STATUS_SUCCESS;
147 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
152 * FUNCTION: Creates a new socket
154 * af = Address family
156 * protocol = Protocol type
157 * lpProtocolInfo = Pointer to protocol information
159 * dwFlags = Socket flags
160 * lpErrno = Address of buffer for error information
162 * Created socket, or INVALID_SOCKET if it could not be created
165 WSAPROTOCOL_INFOW ProtocolInfo;
166 UNICODE_STRING TdiDeviceName;
167 DWORD NotificationEvents;
168 PWSHELPER_DLL HelperDLL;
178 AFD_DbgPrint(MAX_TRACE, ("af (%d) type (%d) protocol (%d).\n",
179 af, type, protocol));
// No protocol info supplied: build a zeroed one from af/type/protocol so that
// LocateHelperDLL has something to match against.
181 if (!lpProtocolInfo) {
182 lpProtocolInfo = &ProtocolInfo;
183 ZeroMemory(&ProtocolInfo, sizeof(WSAPROTOCOL_INFOW));
185 ProtocolInfo.iAddressFamily = af;
186 ProtocolInfo.iSocketType = type;
187 ProtocolInfo.iProtocol = protocol;
190 HelperDLL = LocateHelperDLL(lpProtocolInfo);
192 *lpErrno = WSAEAFNOSUPPORT;
193 return INVALID_SOCKET;
196 AddressFamily = lpProtocolInfo->iAddressFamily;
197 SocketType = lpProtocolInfo->iSocketType;
198 Protocol = lpProtocolInfo->iProtocol;
// Ask the WinSock helper DLL to validate the triple and report the TDI device and
// notification events (argument lines partly not shown).
200 Status = HelperDLL->EntryTable.lpWSHOpenSocket2(
208 &NotificationEvents);
209 if (Status != NO_ERROR) {
210 AFD_DbgPrint(MAX_TRACE, ("WinSock Helper DLL failed (0x%X).\n", Status));
212 return INVALID_SOCKET;
215 NtStatus = OpenSocket(&Socket,
// TdiDeviceName (presumably filled by lpWSHOpenSocket2) is released whether or
// not OpenSocket succeeded.
223 RtlFreeUnicodeString(&TdiDeviceName);
224 if (!NT_SUCCESS(NtStatus)) {
// BUG(review): this converts Status (the helper result, known to be NO_ERROR here because
// of the check above) instead of NtStatus, so *lpErrno ends up 0 (ERROR_SUCCESS) on
// failure. It should presumably be RtlNtStatusToDosError(NtStatus). Note that this yields
// a Win32 error code, not a WSAE* code.
225 *lpErrno = RtlNtStatusToDosError(Status);
226 return INVALID_SOCKET;
229 /* FIXME: Assumes catalog entry id to be 1 */
230 Socket2 = Upcalls.lpWPUModifyIFSHandle(1, Socket, lpErrno);
232 if (Socket2 == INVALID_SOCKET) {
// NOTE(review): the AFD handle in Socket is not closed on this path, so it leaks.
234 AFD_DbgPrint(MIN_TRACE, ("FIXME: Cleanup.\n"));
235 return INVALID_SOCKET;
240 AFD_DbgPrint(MID_TRACE, ("Returning socket descriptor (0x%X).\n", Socket2));
252 * FUNCTION: Closes an open socket
254 * s = Socket descriptor
255 * lpErrno = Address of buffer for error information
257 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
262 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
// The SOCKET value is the AFD file handle, so closing the socket is a plain NtClose.
264 Status = NtClose((HANDLE)s);
266 if (NT_SUCCESS(Status)) {
// In the visible code the only error reported is WSAENOTSOCK, whatever NtClose returned.
271 *lpErrno = WSAENOTSOCK;
280 IN CONST LPSOCKADDR name,
284 * FUNCTION: Associates a local address with a socket
286 * s = Socket descriptor
287 * name = Pointer to local address
288 * namelen = Length of name
289 * lpErrno = Address of buffer for error information
291 * 0, or SOCKET_ERROR if the socket could not be bound
294 FILE_REQUEST_BIND Request;
295 FILE_REPLY_BIND Reply;
296 IO_STATUS_BLOCK Iosb;
299 AFD_DbgPrint(MAX_TRACE, ("s (0x%X) name (0x%X) namelen (%d).\n", s, name, namelen));
// NOTE(review): this always copies sizeof(SOCKADDR) bytes and (in the visible code)
// ignores namelen. It reads past `name` when namelen < sizeof(SOCKADDR) and truncates
// larger address types. Validate namelen (WSAEFAULT) and bound the copy by it.
301 RtlCopyMemory(&Request.Name, name, sizeof(SOCKADDR));
303 Status = NtDeviceIoControlFile(
311 sizeof(FILE_REQUEST_BIND),
313 sizeof(FILE_REPLY_BIND));
314 if (Status == STATUS_PENDING) {
315 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
316 /* FIXME: Wait only for blocking sockets */
// After this, Status is the wait result rather than the I/O outcome;
// Iosb.Status is not consulted in the visible code.
317 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
320 if (!NT_SUCCESS(Status)) {
// NOTE(review): Reply is only written by the driver. If NtDeviceIoControlFile fails
// before that, Reply.Status is uninitialized here unless set in lines not shown.
321 *lpErrno = Reply.Status;
335 * FUNCTION: Listens for incoming connections
337 * s = Socket descriptor
338 * backlog = Maximum number of pending connection requests
339 * lpErrno = Address of buffer for error information
341 * 0, or SOCKET_ERROR if the socket could not be put into the listening state
344 FILE_REQUEST_LISTEN Request;
345 FILE_REPLY_LISTEN Reply;
346 IO_STATUS_BLOCK Iosb;
349 AFD_DbgPrint(MAX_TRACE, ("s (0x%X) backlog (%d).\n", s, backlog));
// backlog is forwarded unvalidated; any clamping is left to the driver.
351 Request.Backlog = backlog;
353 Status = NtDeviceIoControlFile(
361 sizeof(FILE_REQUEST_LISTEN),
363 sizeof(FILE_REPLY_LISTEN));
364 if (Status == STATUS_PENDING) {
365 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
366 /* FIXME: Wait only for blocking sockets */
// After this, Status is the wait result rather than the I/O outcome;
// Iosb.Status is not consulted in the visible code.
367 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
370 if (!NT_SUCCESS(Status)) {
// NOTE(review): Reply.Status may be uninitialized if the IOCTL failed before the
// driver wrote the reply.
371 *lpErrno = Reply.Status;
383 IN OUT LPFD_SET readfds,
384 IN OUT LPFD_SET writefds,
385 IN OUT LPFD_SET exceptfds,
386 IN CONST LPTIMEVAL timeout,
389 * FUNCTION: Returns status of one or more sockets
391 * nfds = Always ignored
392 * readfds = Pointer to socket set to be checked for readability (optional)
393 * writefds = Pointer to socket set to be checked for writability (optional)
394 * exceptfds = Pointer to socket set to be checked for errors (optional)
395 * timeout = Pointer to a TIMEVAL structure indicating maximum wait time
396 * (NULL means wait forever)
397 * lpErrno = Address of buffer for error information
399 * Number of ready socket descriptors, or SOCKET_ERROR if an error occurred
402 PFILE_REQUEST_SELECT Request;
403 FILE_REPLY_SELECT Reply;
404 IO_STATUS_BLOCK Iosb;
412 AFD_DbgPrint(MAX_TRACE, ("readfds (0x%X) writefds (0x%X) exceptfds (0x%X).\n",
413 readfds, writefds, exceptfds));
415 /* FIXME: For now, all reads are timed out immediately */
416 if (readfds != NULL) {
417 AFD_DbgPrint(MID_TRACE, ("Timing out read query.\n"));
418 *lpErrno = WSAETIMEDOUT;
422 /* FIXME: For now, always allow write */
423 if (writefds != NULL) {
424 AFD_DbgPrint(MID_TRACE, ("Setting one socket writeable.\n"));
// Each non-empty set is serialized as its count followed by fd_count SOCKETs.
// NOTE(review): sizeof(UINT) assumes fd_array directly follows fd_count. On 64-bit,
// alignment padding precedes fd_array, so this under-counts and truncates the last
// entry. FIELD_OFFSET(FD_SET, fd_array) + fd_count * sizeof(SOCKET) avoids that.
431 if ((readfds != NULL) && (readfds->fd_count > 0)) {
432 ReadSize = (readfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
436 if ((writefds != NULL) && (writefds->fd_count > 0)) {
437 WriteSize = (writefds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
441 if ((exceptfds != NULL) && (exceptfds->fd_count > 0)) {
442 ExceptSize = (exceptfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
445 Size = ReadSize + WriteSize + ExceptSize;
447 Request = (PFILE_REQUEST_SELECT)HeapAlloc(
448 GlobalHeap, 0, sizeof(FILE_REQUEST_SELECT) + Size);
450 *lpErrno = WSAENOBUFS;
454 /* Put FD SETs after request structure */
// NOTE(review): `Current += <size>` advances by bytes only if Current is a byte pointer.
// Its declaration is not visible; confirm it is PCHAR/PUCHAR, not PFILE_REQUEST_SELECT.
// Also confirm the read branch advances Current by ReadSize (that line is not shown);
// otherwise the write set overwrites the read set.
455 Current = (Request + 1);
458 Request->ReadFDSet = (LPFD_SET)Current;
460 RtlCopyMemory(Request->ReadFDSet, readfds, ReadSize);
462 Request->ReadFDSet = NULL;
466 Request->WriteFDSet = (LPFD_SET)Current;
467 Current += WriteSize;
468 RtlCopyMemory(Request->WriteFDSet, writefds, WriteSize);
470 Request->WriteFDSet = NULL;
473 if (ExceptSize > 0) {
474 Request->ExceptFDSet = (LPFD_SET)Current;
475 RtlCopyMemory(Request->ExceptFDSet, exceptfds, ExceptSize);
477 Request->ExceptFDSet = NULL;
480 AFD_DbgPrint(MAX_TRACE, ("R1 (0x%X) W1 (0x%X).\n", Request->ReadFDSet, Request->WriteFDSet));
482 Status = NtDeviceIoControlFile(
490 sizeof(FILE_REQUEST_SELECT) + Size,
492 sizeof(FILE_REPLY_SELECT));
// NOTE(review): Request is freed before any pending wait below. This is only safe if the
// driver has already captured the input buffer (e.g. a buffered IOCTL); confirm.
494 HeapFree(GlobalHeap, 0, Request);
496 if (Status == STATUS_PENDING) {
497 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
498 /* FIXME: Wait only for blocking sockets */
// Waits on the shared command channel rather than on a socket handle.
499 Status = NtWaitForSingleObject(CommandChannel, FALSE, NULL);
// Every failure, whatever its NTSTATUS, is reported as WSAENOBUFS.
502 if (!NT_SUCCESS(Status)) {
503 AFD_DbgPrint(MAX_TRACE, ("Status (0x%X).\n", Status));
504 *lpErrno = WSAENOBUFS;
508 AFD_DbgPrint(MAX_TRACE, ("Select successful. Status (0x%X) Count (0x%X).\n",
509 Reply.Status, Reply.SocketCount));
// *lpErrno is written from Reply.Status even on the success path.
511 *lpErrno = Reply.Status;
513 return Reply.SocketCount;
521 IN OUT LPINT addrlen,
522 IN LPCONDITIONPROC lpfnCondition,
523 IN DWORD dwCallbackData,
526 FILE_REQUEST_ACCEPT Request;
527 FILE_REPLY_ACCEPT Reply;
528 IO_STATUS_BLOCK Iosb;
531 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
// NOTE(review): addrlen is dereferenced unconditionally here and after the call, but the
// WSPAccept SPI allows addr/addrlen to be NULL. Guard both dereferences.
534 Request.addrlen = *addrlen;
// The user-mode condition callback pointer and its context are forwarded as-is.
535 Request.lpfnCondition = lpfnCondition;
536 Request.dwCallbackData = dwCallbackData;
538 Status = NtDeviceIoControlFile(
546 sizeof(FILE_REQUEST_ACCEPT),
548 sizeof(FILE_REPLY_ACCEPT));
549 if (Status == STATUS_PENDING) {
550 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
551 /* FIXME: Wait only for blocking sockets */
552 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
555 if (!NT_SUCCESS(Status)) {
// NOTE(review): Reply.Status may be uninitialized if the IOCTL failed before the
// driver wrote the reply.
556 *lpErrno = Reply.Status;
557 return INVALID_SOCKET;
560 *addrlen = Reply.addrlen;
570 IN CONST LPSOCKADDR name,
572 IN LPWSABUF lpCallerData,
573 OUT LPWSABUF lpCalleeData,
578 FILE_REQUEST_CONNECT Request;
579 FILE_REPLY_CONNECT Reply;
580 IO_STATUS_BLOCK Iosb;
583 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
// Caller-supplied buffers and QOS pointers are forwarded to the driver as raw pointers.
586 Request.namelen = namelen;
587 Request.lpCallerData = lpCallerData;
588 Request.lpCalleeData = lpCalleeData;
589 Request.lpSQOS = lpSQOS;
590 Request.lpGQOS = lpGQOS;
592 Status = NtDeviceIoControlFile(
600 sizeof(FILE_REQUEST_CONNECT),
602 sizeof(FILE_REPLY_CONNECT));
603 if (Status == STATUS_PENDING) {
604 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
605 /* FIXME: Wait only for blocking sockets */
606 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
609 if (!NT_SUCCESS(Status)) {
// NOTE(review): Reply.Status may be uninitialized if the IOCTL failed early.
610 *lpErrno = Reply.Status;
// NOTE(review): the WSPConnect SPI reports failure with SOCKET_ERROR. INVALID_SOCKET is
// the SOCKET-typed sentinel and only equals -1 after conversion to int.
611 return INVALID_SOCKET;
618 NTSTATUS OpenCommandChannel(
621 * FUNCTION: Opens a command channel to afd.sys
625 * Status of operation
628 OBJECT_ATTRIBUTES ObjectAttributes;
629 PAFD_SOCKET_INFORMATION SocketInfo;
630 PFILE_FULL_EA_INFORMATION EaInfo;
631 UNICODE_STRING DeviceName;
632 IO_STATUS_BLOCK Iosb;
638 AFD_DbgPrint(MAX_TRACE, ("Called\n"));
640 EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
642 sizeof(AFD_SOCKET_INFORMATION);
646 EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
648 return STATUS_INSUFFICIENT_RESOURCES;
651 RtlZeroMemory(EaInfo, EaLength);
652 EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
653 RtlCopyMemory(EaInfo->EaName,
656 EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
// In the visible lines only CommandChannel is set; the other AFD_SOCKET_INFORMATION
// fields keep the zeroes written by RtlZeroMemory above.
658 SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
659 SocketInfo->CommandChannel = TRUE;
661 RtlInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
662 InitializeObjectAttributes(
669 Status = NtCreateFile(
671 FILE_GENERIC_READ | FILE_GENERIC_WRITE,
678 FILE_SYNCHRONOUS_IO_ALERT,
// NOTE(review): no HeapFree(GlobalHeap, 0, EaInfo) is visible in this excerpt. Confirm
// the EA buffer is released on both the success and failure paths.
682 if (!NT_SUCCESS(Status)) {
683 AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
// Publish the handle in the global used by WSPSelect and CloseCommandChannel.
688 CommandChannel = FileHandle;
690 return STATUS_SUCCESS;
694 NTSTATUS CloseCommandChannel(
697 * FUNCTION: Closes command channel to afd.sys
701 * Status of operation
704 AFD_DbgPrint(MAX_TRACE, ("Called.\n"));
// Returns NtClose's status. CommandChannel is not reset afterwards, so the
// global keeps the stale handle value.
706 return NtClose(CommandChannel);
713 IN WORD wVersionRequested,
714 OUT LPWSPDATA lpWSPData,
715 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
716 IN WSPUPCALLTABLE UpcallTable,
717 OUT LPWSPPROC_TABLE lpProcTable)
719 * FUNCTION: Initialize service provider for a client
721 * wVersionRequested = Highest WinSock SPI version that the caller can use
722 * lpWSPData = Address of WSPDATA structure to initialize
723 * lpProtocolInfo = Pointer to structure that defines the desired protocol
724 * UpcallTable = Pointer to upcall table of the WinSock DLL
725 * lpProcTable = Address of procedure table to initialize
727 * Status of operation
733 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
735 EnterCriticalSection(&InitCriticalSection);
// The upcall table is replaced on every call, under the lock.
737 Upcalls = UpcallTable;
739 if (StartupCount == 0) {
740 /* First time called */
742 Status = OpenCommandChannel();
743 if (NT_SUCCESS(Status)) {
// NOTE(review): GetModuleHandle is the TCHAR macro but receives a wide literal, which
// only works in a UNICODE build. GetModuleHandleW states the intent explicitly.
744 hWS2_32 = GetModuleHandle(L"ws2_32.dll");
745 if (hWS2_32 != NULL) {
746 lpWPUCompleteOverlappedRequest = (LPWPUCOMPLETEOVERLAPPEDREQUEST)
747 GetProcAddress(hWS2_32, "WPUCompleteOverlappedRequest");
748 if (lpWPUCompleteOverlappedRequest != NULL) {
753 AFD_DbgPrint(MIN_TRACE, ("GetModuleHandle() failed for ws2_32.dll\n"));
756 AFD_DbgPrint(MIN_TRACE, ("Cannot open afd.sys\n"));
763 LeaveCriticalSection(&InitCriticalSection);
// NOTE(review): Status may still hold an NTSTATUS from OpenCommandChannel (unless
// reassigned in lines not shown). Comparing it with NO_ERROR works only because
// STATUS_SUCCESS and NO_ERROR are both 0.
765 if (Status == NO_ERROR) {
766 lpProcTable->lpWSPAccept = WSPAccept;
767 lpProcTable->lpWSPAddressToString = WSPAddressToString;
768 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
769 lpProcTable->lpWSPBind = WSPBind;
770 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
771 lpProcTable->lpWSPCleanup = WSPCleanup;
772 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
773 lpProcTable->lpWSPConnect = WSPConnect;
774 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
775 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
776 lpProcTable->lpWSPEventSelect = WSPEventSelect;
777 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
778 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
779 lpProcTable->lpWSPGetSockName = WSPGetSockName;
780 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
781 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
782 lpProcTable->lpWSPIoctl = WSPIoctl;
783 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
784 lpProcTable->lpWSPListen = WSPListen;
785 lpProcTable->lpWSPRecv = WSPRecv;
786 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
787 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
788 lpProcTable->lpWSPSelect = WSPSelect;
789 lpProcTable->lpWSPSend = WSPSend;
790 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
791 lpProcTable->lpWSPSendTo = WSPSendTo;
792 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
793 lpProcTable->lpWSPShutdown = WSPShutdown;
794 lpProcTable->lpWSPSocket = WSPSocket;
795 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
// Always advertises SPI 2.2; no check against wVersionRequested is visible here.
797 lpWSPData->wVersion = MAKEWORD(2, 2);
798 lpWSPData->wHighVersion = MAKEWORD(2, 2);
801 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
812 * FUNCTION: Cleans up service provider for a client
814 * lpErrno = Address of buffer for error information
816 * 0 if successful, or SOCKET_ERROR if not
819 AFD_DbgPrint(MAX_TRACE, ("\n"));
821 EnterCriticalSection(&InitCriticalSection);
// StartupCount is presumably decremented in the line not shown inside this check.
// The command channel is closed once the count reaches 0.
823 if (StartupCount > 0) {
826 if (StartupCount == 0) {
827 AFD_DbgPrint(MAX_TRACE, ("Cleaning up msafd.dll.\n"));
// NOTE(review): the NTSTATUS returned by CloseCommandChannel is ignored.
829 CloseCommandChannel();
833 LeaveCriticalSection(&InitCriticalSection);
835 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
845 DllMain(HANDLE hInstDll,
849 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll\n"));
852 case DLL_PROCESS_ATTACH:
853 /* Don't need thread attach notifications
854 so disable them to improve performance */
855 DisableThreadLibraryCalls(hInstDll);
// Protects StartupCount and the command channel in WSPStartup/WSPCleanup.
857 InitializeCriticalSection(&InitCriticalSection);
859 GlobalHeap = GetProcessHeap();
// NOTE(review): DllMain runs under the loader lock. If CreateHelperDLLDatabase reads the
// registry or loads helper DLLs, consider deferring it to the first WSPStartup.
861 CreateHelperDLLDatabase();
864 case DLL_THREAD_ATTACH:
867 case DLL_THREAD_DETACH:
870 case DLL_PROCESS_DETACH:
// In the visible code the AFD command channel is not closed here; only WSPCleanup
// closes it, when StartupCount reaches 0.
872 DestroyHelperDLLDatabase();
874 DeleteCriticalSection(&InitCriticalSection);
879 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));