[MSAFD][AFD] Asynchronous connection support (#8573)

* [MSAFD] async support

- Allow APCs to be processed while waiting for blocking IO
- Remove SleepEx hacks

[MSAFD][AFD][TCPIP] Patch CORE-12152

Original author: Peter Hater <7element@mail.bg>

TCPConnect: Always copy bind address into connection.

[MSAFD] SetSocketInformation: Remove unnecessary TRACE code blocks

[MSAFD] AfdRecvAPC and AfdSendAPC: Remove unnecessary NULL checks

[MSAFD] WSPRecv, WSPRecvFrom and WSPSend: Fix leaks and fix IOCTL input buffer handling

[MSAFD] Rewrite WSPSendTo

[MSAFD] Rewrite WSPConnect and AfdConnectAPC

[MSAFD] Remove hacky SleepEx calls

[MSAFD] Don't wait with no timeout

[MSAFD] send/receive: Fixes for overlapped socket

- Fix IOCTL buffers handling
- Allow APC to be processed

[MSAFD] connect: Fixes for async connect

- Allow APC to be processed
- Remove "mark as connected" hack for non-blocking socket

[MSAFD] Get/SetSocketInformation: Fixes for overlapped socket

- Fix IOCTL buffer handling
- Allow APC to be processed

[MSAFD] APC fixes

Fix New Moon 28 crash on website ads.

[MSAFD] Code cleanup

- Rename APC context data type
- Rename APC functions (Afd* -> Msafd*)
- Add SAL annotation
- BUGBUG -> FIXME
- Use sizeof(*ptr) instead of sizeof(type)
- Fix debug prints

[MSAFD] Add comment for SleepEx hacks

[MSAFD] WSPGetOverlappedResult: Fix failed ws2_32:recv tests

[MSAFD] WSPConnect: Fix IOCTL status handling

[MSAFD] send/receive: Simplify IOCTL status handling

[MSAFD] Get/SetSocketInformation: Simplify IOCTL status handling

Co-Authored-By: Emmanuel Gonzalez <momo16136@tamu.edu>
Co-Authored-By: Julio Carchi Ruiz <juliocarchi@yahoo.com>

* [MSAFD] Fixes
- pass overlapped structure pointer to apc context when there no completion routine. .NET expects this behavior
- Fix WSPGetOverlappedResult to return flags and status code upon failure
---------
Original author: Peter Hater <7element@mail.bg>
Co-authored-by: Thamatip Chitpong <thamatip.chitpong@reactos.org>
Co-authored-by: Emmanuel Gonzalez <momo16136@tamu.edu>
Co-authored-by: Julio Carchi Ruiz <juliocarchi@yahoo.com>
This commit is contained in:
Mikhail Tyukin
2026-01-18 14:42:24 -05:00
committed by GitHub
parent fda2ee2ad8
commit 4e7a64b3ec
6 changed files with 715 additions and 451 deletions

View File

@@ -418,7 +418,7 @@ ok:
return (SOCKET)Sock;
error:
ERR("Ending %x\n", Status);
TRACE("Ending %x\n", Status);
if( SharedData )
{
@@ -610,6 +610,20 @@ TranslateNtStatusError(NTSTATUS Status)
}
}
VOID
MsafdWaitForAlert(
_In_ HANDLE hObject)
{
DWORD Result;
for (;;)
{
Result = WaitForSingleObjectEx(hObject, INFINITE, TRUE);
if (Result != WAIT_IO_COMPLETION)
break;
}
}
/*
* FUNCTION: Closes an open socket
* ARGUMENTS:
@@ -640,18 +654,8 @@ WSPCloseSocket(IN SOCKET Handle,
return SOCKET_ERROR;
}
/* Create the Wait Event */
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
NULL,
SynchronizationEvent,
FALSE);
if(!NT_SUCCESS(Status))
{
ERR("NtCreateEvent failed: 0x%08x\n", Status);
return SOCKET_ERROR;
}
/* HACK: Allow APC to be processed */
SleepEx(0, TRUE);
if (Socket->HelperEvents & WSH_NOTIFY_CLOSE)
{
@@ -663,9 +667,8 @@ WSPCloseSocket(IN SOCKET Handle,
if (Status)
{
if (lpErrno) *lpErrno = Status;
ERR("WSHNotify failed. Error 0x%#x\n", Status);
NtClose(SockEvent);
if (lpErrno) *lpErrno = Status;
return SOCKET_ERROR;
}
}
@@ -674,11 +677,22 @@ WSPCloseSocket(IN SOCKET Handle,
if (Socket->SharedData->State == SocketClosed)
{
WARN("Socket is closing.\n");
NtClose(SockEvent);
if (lpErrno) *lpErrno = WSAENOTSOCK;
return SOCKET_ERROR;
}
/* Create the Wait Event */
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
NULL,
SynchronizationEvent,
FALSE);
if (!NT_SUCCESS(Status))
{
ERR("NtCreateEvent failed: 0x%08x\n", Status);
return SOCKET_ERROR;
}
/* Decrement reference count on SharedData */
References = InterlockedDecrement(&Socket->SharedData->RefCount);
if (References)
@@ -775,7 +789,7 @@ WSPCloseSocket(IN SOCKET Handle,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IoStatusBlock.Status;
}
}
@@ -943,7 +957,7 @@ WSPBind(SOCKET Handle,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -1028,7 +1042,7 @@ WSPListen(SOCKET Handle,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -1304,7 +1318,7 @@ WSPSelect(IN int nfds,
/* Wait for Completion */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -1531,7 +1545,7 @@ WSPAccept(
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -1564,7 +1578,7 @@ WSPAccept(
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -1604,7 +1618,7 @@ WSPAccept(
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -1709,7 +1723,7 @@ WSPAccept(
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -1763,7 +1777,7 @@ WSPAccept(
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -1825,6 +1839,60 @@ WSPAccept(
return AcceptSocket;
}
static
VOID
NTAPI
MsafdConnectAPC(
_In_ PVOID ApcContext,
_In_ PIO_STATUS_BLOCK IoStatusBlock,
_In_ ULONG Reserved)
{
PMSAFD_CONNECT_APC_CONTEXT Context = ApcContext;
TRACE("MsafdConnectAPC(%p %lx %lx)\n", ApcContext, IoStatusBlock->Status, IoStatusBlock->Information);
PSOCKET_INFORMATION Socket = GetSocketStructure(Context->lpSocket);
if (!Socket)
{
// FIXME: Socket is closed before this APC could run
HeapFree(GlobalHeap, 0, ApcContext);
return;
}
Socket->SharedData->SocketLastError = TranslateNtStatusError(IoStatusBlock->Status);
if (IoStatusBlock->Status == STATUS_SUCCESS)
{
Socket->SharedData->State = SocketConnected;
Socket->TdiConnectionHandle = (HANDLE)IoStatusBlock->Information;
Socket->SharedData->ConnectTime = GetCurrentTimeInSeconds();
}
/* Re-enable Async Event */
SockReenableAsyncSelectEvent(Socket, FD_WRITE);
/* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
if (IoStatusBlock->Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT))
{
Socket->HelperData->WSHNotify(Socket->HelperContext,
Socket->Handle,
Socket->TdiAddressHandle,
Socket->TdiConnectionHandle,
WSH_NOTIFY_CONNECT);
}
else if (IoStatusBlock->Status != STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT_ERROR))
{
Socket->HelperData->WSHNotify(Socket->HelperContext,
Socket->Handle,
Socket->TdiAddressHandle,
Socket->TdiConnectionHandle,
WSH_NOTIFY_CONNECT_ERROR);
}
HeapFree(GlobalHeap, 0, ApcContext);
}
int
WSPAPI
WSPConnect(SOCKET Handle,
@@ -1836,19 +1904,21 @@ WSPConnect(SOCKET Handle,
LPQOS lpGQOS,
LPINT lpErrno)
{
IO_STATUS_BLOCK IOSB;
PAFD_CONNECT_INFO ConnectInfo = NULL;
PSOCKET_INFORMATION Socket;
NTSTATUS Status;
INT Errno;
ULONG ConnectDataLength;
ULONG InConnectDataLength;
INT BindAddressLength;
PSOCKADDR BindAddress;
HANDLE SockEvent;
int SocketDataLength;
IO_STATUS_BLOCK DummyIOSB;
PIO_STATUS_BLOCK IOSB = &DummyIOSB;
PAFD_CONNECT_INFO ConnectInfo = NULL;
PSOCKET_INFORMATION Socket;
NTSTATUS Status;
INT Errno;
ULONG ConnectDataLength;
ULONG InConnectDataLength;
HANDLE SockEvent;
int SocketDataLength;
PMSAFD_CONNECT_APC_CONTEXT APCContext = NULL;
PIO_APC_ROUTINE APCFunction = NULL;
UCHAR Buffer[128];
TRACE("Called (%lx) %lx:%d\n", Handle, ((const struct sockaddr_in *)SocketAddress)->sin_addr, ((const struct sockaddr_in *)SocketAddress)->sin_port);
TRACE("WSPConnect(%x)\n", Handle);
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
@@ -1858,34 +1928,38 @@ WSPConnect(SOCKET Handle,
return SOCKET_ERROR;
}
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
NULL,
SynchronizationEvent,
FALSE);
if (!NT_SUCCESS(Status))
return SOCKET_ERROR;
/* Bind us First */
if (Socket->SharedData->State == SocketOpen)
{
INT BindAddressLength;
PSOCKADDR BindAddress;
INT BindError;
/* Get the Wildcard Address */
BindAddressLength = Socket->HelperData->MaxWSAddressLength;
BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
if (!BindAddress)
{
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
}
Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext,
BindAddress,
&BindAddressLength);
Socket->HelperData->WSHGetWildcardSockaddr(Socket->HelperContext,
BindAddress,
&BindAddressLength);
/* Bind it */
if (WSPBind(Handle, BindAddress, BindAddressLength, lpErrno) == SOCKET_ERROR)
BindError = WSPBind(Handle, BindAddress, BindAddressLength, lpErrno);
HeapFree(GetProcessHeap(), 0, BindAddress);
if (BindError == SOCKET_ERROR)
return SOCKET_ERROR;
}
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
NULL,
SynchronizationEvent,
FALSE);
if (!NT_SUCCESS(Status))
return SOCKET_ERROR;
/* Set the Connect Data */
if (lpCallerData != NULL)
{
@@ -1894,7 +1968,7 @@ WSPConnect(SOCKET Handle,
SockEvent,
NULL,
NULL,
&IOSB,
IOSB,
IOCTL_AFD_SET_CONNECT_DATA,
lpCallerData->buf,
ConnectDataLength,
@@ -1903,25 +1977,25 @@ WSPConnect(SOCKET Handle,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
Status = IOSB.Status;
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
if (Status != STATUS_SUCCESS)
goto notify;
goto Leave;
}
/* Calculate the size of SocketAddress->sa_data */
SocketDataLength = SocketAddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
/* Allocate a connection info buffer with SocketDataLength bytes of payload */
ConnectInfo = HeapAlloc(GetProcessHeap(), 0,
FIELD_OFFSET(AFD_CONNECT_INFO,
RemoteAddress.Address[0].Address[SocketDataLength]));
if (!ConnectInfo)
ConnectInfo = (PAFD_CONNECT_INFO)Buffer;
int connectionInfoSize = FIELD_OFFSET(AFD_CONNECT_INFO, RemoteAddress.Address[0].Address[SocketDataLength]);
if (connectionInfoSize > 128)
{
Status = STATUS_INSUFFICIENT_RESOURCES;
goto notify;
*lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
/* Set up Address in TDI Format */
@@ -1951,7 +2025,7 @@ WSPConnect(SOCKET Handle,
SockEvent,
NULL,
NULL,
&IOSB,
IOSB,
IOCTL_AFD_SET_CONNECT_DATA_SIZE,
&InConnectDataLength,
sizeof(InConnectDataLength),
@@ -1961,12 +2035,12 @@ WSPConnect(SOCKET Handle,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
Status = IOSB.Status;
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
if (Status != STATUS_SUCCESS)
goto notify;
goto Leave;
}
/* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
@@ -1974,46 +2048,76 @@ WSPConnect(SOCKET Handle,
ConnectInfo->UseSAN = FALSE;
ConnectInfo->Unknown = 0;
/* FIXME: Handle Async Connect */
/* Verify if we should use APC */
if (Socket->SharedData->NonBlocking)
{
ERR("Async Connect UNIMPLEMENTED!\n");
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext));
if (!APCContext)
{
ERR("Not enough memory for APC Context\n");
Status = STATUS_INSUFFICIENT_RESOURCES;
goto Leave;
}
APCContext->lpSocket = Handle;
APCFunction = &MsafdConnectAPC;
IOSB = &APCContext->IoStatusBlock;
}
IOSB->Status = STATUS_PENDING;
IOSB->Information = 0;
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)Handle,
SockEvent,
NULL,
NULL,
&IOSB,
APCFunction ? NULL : SockEvent,
APCFunction,
APCContext,
IOSB,
IOCTL_AFD_CONNECT,
ConnectInfo,
0x22,
connectionInfoSize,
NULL,
0);
/* Wait for return */
if (Status == STATUS_PENDING)
if (Socket->SharedData->NonBlocking)
{
WaitForSingleObject(SockEvent, INFINITE);
Status = IOSB.Status;
if (Status == STATUS_PENDING)
{
Status = STATUS_CANT_WAIT; // WSAEWOULDBLOCK
goto Leave;
}
else
{
/* HACK: Allow APC to be processed */
SleepEx(0, TRUE);
}
}
else
{
/* Wait for completion if blocking */
if (Status == STATUS_PENDING)
{
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
Socket->SharedData->SocketLastError = TranslateNtStatusError(Status);
if (Status != STATUS_SUCCESS)
goto Leave;
Socket->SharedData->State = SocketConnected;
Socket->TdiConnectionHandle = (HANDLE)IOSB->Information;
Socket->SharedData->ConnectTime = GetCurrentTimeInSeconds();
}
Socket->SharedData->SocketLastError = TranslateNtStatusError(Status);
if (Status != STATUS_SUCCESS)
goto notify;
Socket->SharedData->State = SocketConnected;
Socket->TdiConnectionHandle = (HANDLE)IOSB.Information;
Socket->SharedData->ConnectTime = GetCurrentTimeInSeconds();
/* Get any pending connect data */
if (lpCalleeData != NULL)
if (lpCalleeData != NULL && Status == STATUS_SUCCESS)
{
IOSB = &DummyIOSB;
Status = NtDeviceIoControlFile((HANDLE)Handle,
SockEvent,
NULL,
NULL,
&IOSB,
IOSB,
IOCTL_AFD_GET_CONNECT_DATA,
NULL,
0,
@@ -2022,15 +2126,15 @@ WSPConnect(SOCKET Handle,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
Status = IOSB.Status;
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
}
TRACE("Ending %lx\n", IOSB.Status);
Leave:
TRACE("Ending %lx\n", Status);
notify:
if (ConnectInfo) HeapFree(GetProcessHeap(), 0, ConnectInfo);
NtClose(SockEvent);
/* Re-enable Async Event */
SockReenableAsyncSelectEvent(Socket, FD_WRITE);
@@ -2038,8 +2142,6 @@ notify:
/* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
NtClose(SockEvent);
if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT))
{
Errno = Socket->HelperData->WSHNotify(Socket->HelperContext,
@@ -2138,7 +2240,7 @@ WSPShutdown(SOCKET Handle,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -2218,7 +2320,7 @@ WSPGetSockName(IN SOCKET Handle,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -2322,7 +2424,7 @@ WSPGetPeerName(IN SOCKET s,
/* Wait for return */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -3254,15 +3356,24 @@ WSPCleanup(OUT LPINT lpErrno)
return 0;
}
static
VOID
NTAPI
AfdInfoAPC(PVOID ApcContext,
PIO_STATUS_BLOCK IoStatusBlock,
ULONG Reserved)
MsafdInfoAPC(
_In_ PVOID ApcContext,
_In_ PIO_STATUS_BLOCK IoStatusBlock,
_In_ ULONG Reserved)
{
PAFDAPCCONTEXT Context = ApcContext;
PMSAFD_INFO_APC_CONTEXT Context = ApcContext;
TRACE("MsafdInfoAPC(%p %lx %lx)\n", ApcContext, IoStatusBlock->Status, IoStatusBlock->Information);
if (Context->lpCompletionRoutine)
Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, Context->lpOverlapped, 0);
/* Free IOCTL buffer */
HeapFree(GlobalHeap, 0, Context->lpInfoData);
Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, Context->lpOverlapped, 0);
HeapFree(GlobalHeap, 0, ApcContext);
}
@@ -3275,29 +3386,35 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
LPWSAOVERLAPPED Overlapped OPTIONAL,
LPWSAOVERLAPPED_COMPLETION_ROUTINE CompletionRoutine OPTIONAL)
{
PIO_STATUS_BLOCK IOSB;
IO_STATUS_BLOCK DummyIOSB;
AFD_INFO InfoData;
NTSTATUS Status;
PAFDAPCCONTEXT APCContext;
PIO_APC_ROUTINE APCFunction;
HANDLE Event = NULL;
HANDLE SockEvent;
PIO_STATUS_BLOCK IOSB;
IO_STATUS_BLOCK DummyIOSB;
PAFD_INFO InfoData;
NTSTATUS Status;
PMSAFD_INFO_APC_CONTEXT APCContext;
PIO_APC_ROUTINE APCFunction;
HANDLE Event = NULL;
HANDLE SockEvent;
InfoData = HeapAlloc(GlobalHeap, 0, sizeof(*InfoData));
if (!InfoData)
return SOCKET_ERROR;
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
NULL,
SynchronizationEvent,
FALSE);
if( !NT_SUCCESS(Status) )
if (!NT_SUCCESS(Status))
{
HeapFree(GlobalHeap, 0, InfoData);
return SOCKET_ERROR;
}
/* Set Info Class */
InfoData.InformationClass = AfdInformationClass;
InfoData->InformationClass = AfdInformationClass;
/* Verify if we should use APC */
if (Overlapped == NULL)
if (!Overlapped)
{
/* Not using Overlapped structure, so use normal blocking on event */
APCContext = NULL;
@@ -3311,31 +3428,26 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
{
TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
NtClose( SockEvent );
return 0;
NtClose(SockEvent);
HeapFree(GlobalHeap, 0, InfoData);
return MsafdReturnWithErrno(STATUS_SUCCESS, NULL, 0, NULL);
}
if (CompletionRoutine == NULL)
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext));
if (!APCContext)
{
/* Using Overlapped Structure, but no Completition Routine, so no need for APC */
APCContext = (PAFDAPCCONTEXT)Overlapped;
APCFunction = NULL;
ERR("Not enough memory for APC Context\n");
NtClose(SockEvent);
HeapFree(GlobalHeap, 0, InfoData);
return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, NULL, 0, NULL);
}
APCContext->lpCompletionRoutine = CompletionRoutine;
APCContext->lpOverlapped = Overlapped;
APCContext->lpInfoData = InfoData;
APCFunction = &MsafdInfoAPC;
if (!CompletionRoutine)
Event = Overlapped->hEvent;
}
else
{
/* Using Overlapped Structure and a Completition Routine, so use an APC */
APCFunction = &AfdInfoAPC; // should be a private io completition function inside us
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
if (!APCContext)
{
ERR("Not enough memory for APC Context\n");
NtClose( SockEvent );
return WSAEFAULT;
}
APCContext->lpCompletionRoutine = CompletionRoutine;
APCContext->lpOverlapped = Overlapped;
APCContext->lpSocket = Socket;
}
IOSB = (PIO_STATUS_BLOCK)&Overlapped->Internal;
}
@@ -3349,47 +3461,45 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
APCContext,
IOSB,
IOCTL_AFD_GET_INFO,
&InfoData,
sizeof(InfoData),
&InfoData,
sizeof(InfoData));
InfoData,
sizeof(*InfoData),
InfoData,
sizeof(*InfoData));
/* Wait for return */
if (Status == STATUS_PENDING && Overlapped == NULL)
/* Wait for completion if not overlapped */
if (!Overlapped && Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
NtClose( SockEvent );
TRACE("Status %lx\n", Status);
TRACE("Status %x Information %d\n", Status, IOSB->Information);
if (Status == STATUS_PENDING)
if (Status == STATUS_SUCCESS)
{
TRACE("Leaving (Pending)\n");
return WSA_IO_PENDING;
/* Return Information */
if (Ulong != NULL)
{
*Ulong = InfoData->Information.Ulong;
}
if (LargeInteger != NULL)
{
*LargeInteger = InfoData->Information.LargeInteger;
}
if (Boolean != NULL)
{
*Boolean = InfoData->Information.Boolean;
}
}
if (Status != STATUS_SUCCESS)
return SOCKET_ERROR;
/* Return Information */
if (Ulong != NULL)
NtClose(SockEvent);
if (!APCFunction)
{
*Ulong = InfoData.Information.Ulong;
}
if (LargeInteger != NULL)
{
*LargeInteger = InfoData.Information.LargeInteger;
}
if (Boolean != NULL)
{
*Boolean = InfoData.Information.Boolean;
/* When using APC, this will be freed by the APC function */
HeapFree(GlobalHeap, 0, InfoData);
}
return NO_ERROR;
return MsafdReturnWithErrno(Status, NULL, 0, NULL);
}
@@ -3402,43 +3512,49 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
LPWSAOVERLAPPED Overlapped OPTIONAL,
LPWSAOVERLAPPED_COMPLETION_ROUTINE CompletionRoutine OPTIONAL)
{
PIO_STATUS_BLOCK IOSB;
IO_STATUS_BLOCK DummyIOSB;
AFD_INFO InfoData;
NTSTATUS Status;
PAFDAPCCONTEXT APCContext;
PIO_APC_ROUTINE APCFunction;
HANDLE Event = NULL;
HANDLE SockEvent;
PIO_STATUS_BLOCK IOSB;
IO_STATUS_BLOCK DummyIOSB;
PAFD_INFO InfoData;
NTSTATUS Status;
PMSAFD_INFO_APC_CONTEXT APCContext;
PIO_APC_ROUTINE APCFunction;
HANDLE Event = NULL;
HANDLE SockEvent;
InfoData = HeapAlloc(GlobalHeap, 0, sizeof(*InfoData));
if (!InfoData)
return SOCKET_ERROR;
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
NULL,
SynchronizationEvent,
FALSE);
if( !NT_SUCCESS(Status) )
if (!NT_SUCCESS(Status))
{
HeapFree(GlobalHeap, 0, InfoData);
return SOCKET_ERROR;
}
/* Set Info Class */
InfoData.InformationClass = AfdInformationClass;
InfoData->InformationClass = AfdInformationClass;
/* Set Information */
if (Ulong != NULL)
{
InfoData.Information.Ulong = *Ulong;
InfoData->Information.Ulong = *Ulong;
}
if (LargeInteger != NULL)
{
InfoData.Information.LargeInteger = *LargeInteger;
InfoData->Information.LargeInteger = *LargeInteger;
}
if (Boolean != NULL)
{
InfoData.Information.Boolean = *Boolean;
InfoData->Information.Boolean = *Boolean;
}
/* Verify if we should use APC */
if (Overlapped == NULL)
if (!Overlapped)
{
/* Not using Overlapped structure, so use normal blocking on event */
APCContext = NULL;
@@ -3452,31 +3568,26 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
{
TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
NtClose( SockEvent );
return 0;
NtClose(SockEvent);
HeapFree(GlobalHeap, 0, InfoData);
return MsafdReturnWithErrno(STATUS_SUCCESS, NULL, 0, NULL);
}
if (CompletionRoutine == NULL)
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext));
if (!APCContext)
{
/* Using Overlapped Structure, but no Completition Routine, so no need for APC */
APCContext = (PAFDAPCCONTEXT)Overlapped;
APCFunction = NULL;
ERR("Not enough memory for APC Context\n");
NtClose(SockEvent);
HeapFree(GlobalHeap, 0, InfoData);
return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, NULL, 0, NULL);
}
APCContext->lpCompletionRoutine = CompletionRoutine;
APCContext->lpOverlapped = Overlapped;
APCContext->lpInfoData = InfoData;
APCFunction = &MsafdInfoAPC;
if (!CompletionRoutine)
Event = Overlapped->hEvent;
}
else
{
/* Using Overlapped Structure and a Completition Routine, so use an APC */
APCFunction = &AfdInfoAPC; // should be a private io completition function inside us
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
if (!APCContext)
{
ERR("Not enough memory for APC Context\n");
NtClose( SockEvent );
return WSAEFAULT;
}
APCContext->lpCompletionRoutine = CompletionRoutine;
APCContext->lpOverlapped = Overlapped;
APCContext->lpSocket = Socket;
}
IOSB = (PIO_STATUS_BLOCK)&Overlapped->Internal;
}
@@ -3490,30 +3601,28 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
APCContext,
IOSB,
IOCTL_AFD_SET_INFO,
&InfoData,
sizeof(InfoData),
InfoData,
sizeof(*InfoData),
NULL,
0);
/* Wait for return */
if (Status == STATUS_PENDING && Overlapped == NULL)
/* Wait for completion if not overlapped */
if (!Overlapped && Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
NtClose( SockEvent );
TRACE("Status %lx\n", Status);
TRACE("Status %x Information %d\n", Status, IOSB->Information);
if (Status == STATUS_PENDING)
NtClose(SockEvent);
if (!APCFunction)
{
TRACE("Leaving (Pending)\n");
return WSA_IO_PENDING;
/* When using APC, this will be freed by the APC function */
HeapFree(GlobalHeap, 0, InfoData);
}
return Status == STATUS_SUCCESS ? NO_ERROR : SOCKET_ERROR;
return MsafdReturnWithErrno(Status, NULL, 0, NULL);
}
PSOCKET_INFORMATION
@@ -3581,7 +3690,7 @@ int CreateContext(PSOCKET_INFORMATION Socket)
/* Wait for Completion */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}

View File

@@ -106,7 +106,7 @@ WSPEventSelect(
/* Wait for return */
if (Status == STATUS_PENDING) {
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}
@@ -181,7 +181,7 @@ WSPEnumNetworkEvents(
/* Wait for return */
if (Status == STATUS_PENDING) {
WaitForSingleObject(SockEvent, INFINITE);
MsafdWaitForAlert(SockEvent);
Status = IOSB.Status;
}

View File

@@ -102,6 +102,7 @@ WSPGetOverlappedResult(
OUT LPINT lpErrno)
{
PSOCKET_INFORMATION Socket;
PIO_STATUS_BLOCK IoStatusBlock;
BOOL Ret;
TRACE("Called (%x)\n", Handle);
@@ -120,11 +121,22 @@ WSPGetOverlappedResult(
*lpErrno = WSAEFAULT;
return FALSE;
}
IoStatusBlock = (PIO_STATUS_BLOCK)lpOverlapped;
Ret = GetOverlappedResult((HANDLE)Handle, lpOverlapped, lpdwBytes, fWait);
/* HACK: Allow APC to be processed */
SleepEx(0, TRUE);
if (!fWait && IoStatusBlock->Status == STATUS_PENDING)
{
if (lpErrno)
*lpErrno = WSA_IO_INCOMPLETE;
return FALSE;
}
if (Ret)
{
*lpdwFlags = 0;
/* Re-enable Async Event */
SockReenableAsyncSelectEvent(Socket, FD_OOB);
@@ -132,23 +144,75 @@ WSPGetOverlappedResult(
SockReenableAsyncSelectEvent(Socket, FD_READ);
}
switch (IoStatusBlock->Status)
{
case STATUS_RECEIVE_EXPEDITED:
*lpdwFlags = MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL_EXPEDITED:
*lpdwFlags = MSG_PARTIAL | MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL:
*lpdwFlags = MSG_PARTIAL;
break;
default:
*lpdwFlags = 0;
if (lpErrno)
*lpErrno = TranslateNtStatusError(IoStatusBlock->Status);
break;
}
return Ret;
}
static
VOID
NTAPI
AfdAPC(PVOID ApcContext,
PIO_STATUS_BLOCK IoStatusBlock,
ULONG Reserved)
MsafdRecvAPC(
_In_ PVOID ApcContext,
_In_ PIO_STATUS_BLOCK IoStatusBlock,
_In_ ULONG Reserved)
{
PAFDAPCCONTEXT Context = ApcContext;
PMSAFD_RECV_APC_CONTEXT Context = ApcContext;
TRACE("MsafdRecvAPC(%p %lx %lx)\n", ApcContext, IoStatusBlock->Status, IoStatusBlock->Information);
/* Re-enable Async Event */
if (IoStatusBlock->Status == STATUS_RECEIVE_EXPEDITED ||
IoStatusBlock->Status == STATUS_RECEIVE_PARTIAL_EXPEDITED)
{
SockReenableAsyncSelectEvent(Context->lpSocket, FD_OOB);
}
else
{
SockReenableAsyncSelectEvent(Context->lpSocket, FD_READ);
}
if (Context->lpCompletionRoutine)
Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, (LPWSAOVERLAPPED)IoStatusBlock, 0);
HeapFree(GlobalHeap, 0, ApcContext);
}
static
VOID
NTAPI
MsafdSendAPC(
_In_ PVOID ApcContext,
_In_ PIO_STATUS_BLOCK IoStatusBlock,
_In_ ULONG Reserved)
{
PMSAFD_SEND_APC_CONTEXT Context = ApcContext;
TRACE("MsafdSendAPC(%p %lx %lx)\n", ApcContext, IoStatusBlock->Status, IoStatusBlock->Information);
/* Re-enable Async Event */
SockReenableAsyncSelectEvent(Context->lpSocket, FD_OOB);
SockReenableAsyncSelectEvent(Context->lpSocket, FD_READ);
SockReenableAsyncSelectEvent(Context->lpSocket, FD_WRITE);
Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, Context->lpOverlapped, 0);
if (Context->lpCompletionRoutine)
Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, (LPWSAOVERLAPPED)IoStatusBlock, 0);
/* Free IOCTL buffers */
HeapFree(GlobalHeap, 0, ApcContext);
}
@@ -168,13 +232,14 @@ WSPRecv(SOCKET Handle,
IO_STATUS_BLOCK DummyIOSB;
AFD_RECV_INFO RecvInfo;
NTSTATUS Status;
PVOID APCContext;
PIO_APC_ROUTINE APCFunction;
PMSAFD_RECV_APC_CONTEXT APCContext;
PIO_APC_ROUTINE APCFunction = NULL;
HANDLE Event = NULL;
HANDLE SockEvent;
PSOCKET_INFORMATION Socket;
DWORD NumberOfBytesRead;
TRACE("Called (%x)\n", Handle);
TRACE("WSPRecv(%x)\n", Handle);
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
@@ -197,11 +262,15 @@ WSPRecv(SOCKET Handle,
return SOCKET_ERROR;
}
memset(&RecvInfo, 0, sizeof(RecvInfo));
Status = NtCreateEvent( &SockEvent, EVENT_ALL_ACCESS,
NULL, SynchronizationEvent, FALSE );
if( !NT_SUCCESS(Status) )
return -1;
if (!NT_SUCCESS(Status))
{
return SOCKET_ERROR;
}
/* Set up the Receive Structure */
RecvInfo.BufferArray = (PAFD_WSABUF)lpBuffers;
@@ -210,7 +279,7 @@ WSPRecv(SOCKET Handle,
RecvInfo.AfdFlags = Socket->SharedData->NonBlocking ? AFD_IMMEDIATE : 0;
/* Set the TDI Flags */
if (*ReceiveFlags == 0)
if (!ReceiveFlags || *ReceiveFlags == 0)
{
RecvInfo.TdiFlags |= TDI_RECEIVE_NORMAL;
}
@@ -233,8 +302,7 @@ WSPRecv(SOCKET Handle,
}
/* Verify if we should use APC */
if (lpOverlapped == NULL)
if (!lpOverlapped)
{
/* Not using Overlapped structure, so use normal blocking on event */
APCContext = NULL;
@@ -248,35 +316,48 @@ WSPRecv(SOCKET Handle,
if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
{
TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesRead);
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_SUCCESS, lpErrno, 0, lpNumberOfBytesRead);
}
if (lpCompletionRoutine == NULL)
IOSB = (PIO_STATUS_BLOCK)lpOverlapped;
if (lpCompletionRoutine)
{
/* Using Overlapped Structure, but no Completion Routine, so no need for APC */
APCContext = lpOverlapped;
APCFunction = NULL;
Event = lpOverlapped->hEvent;
}
else
{
/* Using Overlapped Structure and a Completion Routine, so use an APC */
APCFunction = &AfdAPC; // should be a private io completion function inside us
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
RecvInfo.AfdFlags |= AFD_SKIP_FIO;
APCFunction = &MsafdRecvAPC;
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext));
if (!APCContext)
{
ERR("Not enough memory for APC Context\n");
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesRead);
}
((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine;
((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped;
((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket;
RecvInfo.AfdFlags |= AFD_SKIP_FIO;
APCContext->lpCompletionRoutine = lpCompletionRoutine;
APCContext->lpOverlapped = lpOverlapped;
APCContext->lpSocket = Socket;
Event = NULL;
}
else
{
// .NET expects APCContext to be a pointer to lpOverlapped, because it adds some additional metadata past OVERLAPPED structure
// Do not use APC when there is no completion context
Event = lpOverlapped->hEvent;
APCFunction = NULL;
if (Event)
{
APCContext = NULL;
}
else
{
APCContext = (PMSAFD_RECV_APC_CONTEXT)lpOverlapped;
}
}
IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal;
RecvInfo.AfdFlags |= AFD_OVERLAPPED;
}
IOSB->Information = 0;
IOSB->Status = STATUS_PENDING;
/* Send IOCTL */
@@ -291,51 +372,47 @@ WSPRecv(SOCKET Handle,
NULL,
0);
/* Non-blocking sockets must wait until data is available */
if (Status == STATUS_PENDING && Socket->SharedData->NonBlocking)
/* Wait for completion if not overlapped */
if (!lpOverlapped && Status == STATUS_PENDING)
{
if (lpErrno) *lpErrno = WSAEWOULDBLOCK;
return SOCKET_ERROR;
}
/* Wait for completion of not overlapped */
if (Status == STATUS_PENDING && lpOverlapped == NULL)
{
/* It's up to the protocol to time out recv. We must wait
* until the protocol decides it's had enough.
*/
WaitForSingleObject(SockEvent, INFINITE);
/* It's up to the protocol to time out recv. We must wait
* until the protocol decides it's had enough. */
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
NtClose( SockEvent );
TRACE("Status %x Information %d\n", Status, IOSB->Information);
NtClose(SockEvent);
if (Status == STATUS_PENDING)
{
TRACE("Leaving (Pending)\n");
return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesRead);
return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
}
/* Return the Flags */
*ReceiveFlags = 0;
NumberOfBytesRead = (DWORD)IOSB->Information;
switch (Status)
/* Return the Flags */
if (ReceiveFlags)
{
case STATUS_RECEIVE_EXPEDITED:
*ReceiveFlags = MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL_EXPEDITED:
*ReceiveFlags = MSG_PARTIAL | MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL:
*ReceiveFlags = MSG_PARTIAL;
break;
switch (Status)
{
case STATUS_RECEIVE_EXPEDITED:
*ReceiveFlags = MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL_EXPEDITED:
*ReceiveFlags = MSG_PARTIAL | MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL:
*ReceiveFlags = MSG_PARTIAL;
break;
default:
*ReceiveFlags = 0;
break;
}
}
/* Re-enable Async Event */
if (*ReceiveFlags & MSG_OOB)
if (ReceiveFlags && (*ReceiveFlags & MSG_OOB))
{
SockReenableAsyncSelectEvent(Socket, FD_OOB);
}
@@ -344,13 +421,8 @@ WSPRecv(SOCKET Handle,
SockReenableAsyncSelectEvent(Socket, FD_READ);
}
if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine)
{
lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, *ReceiveFlags);
HeapFree(GlobalHeap, 0, (PVOID)APCContext);
}
return MsafdReturnWithErrno ( Status, lpErrno, IOSB->Information, lpNumberOfBytesRead );
TRACE("Leaving (%lx %ld)\n", Status, NumberOfBytesRead);
return MsafdReturnWithErrno(Status, lpErrno, NumberOfBytesRead, lpNumberOfBytesRead);
}
int
@@ -371,11 +443,14 @@ WSPRecvFrom(SOCKET Handle,
IO_STATUS_BLOCK DummyIOSB;
AFD_RECV_INFO_UDP RecvInfo;
NTSTATUS Status;
PVOID APCContext;
PVOID APCFunction;
PMSAFD_RECV_APC_CONTEXT APCContext;
PIO_APC_ROUTINE APCFunction = NULL;
HANDLE Event = NULL;
HANDLE SockEvent;
PSOCKET_INFORMATION Socket;
DWORD NumberOfBytesRead;
TRACE("WSPRecvFrom(%x)\n", Handle);
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
@@ -422,12 +497,15 @@ WSPRecvFrom(SOCKET Handle,
if (WSPBind(Handle, SocketAddress, *SocketAddressLength, lpErrno) == SOCKET_ERROR)
return SOCKET_ERROR;
}
memset(&RecvInfo, 0, sizeof(RecvInfo));
Status = NtCreateEvent( &SockEvent, EVENT_ALL_ACCESS,
NULL, SynchronizationEvent, FALSE );
if( !NT_SUCCESS(Status) )
return -1;
{
return SOCKET_ERROR;
}
/* Set up the Receive Structure */
RecvInfo.BufferArray = (PAFD_WSABUF)lpBuffers;
@@ -438,7 +516,7 @@ WSPRecvFrom(SOCKET Handle,
RecvInfo.Address = SocketAddress;
/* Set the TDI Flags */
if (*ReceiveFlags == 0)
if (!ReceiveFlags || *ReceiveFlags == 0)
{
RecvInfo.TdiFlags |= TDI_RECEIVE_NORMAL;
}
@@ -461,8 +539,7 @@ WSPRecvFrom(SOCKET Handle,
}
/* Verify if we should use APC */
if (lpOverlapped == NULL)
if (!lpOverlapped)
{
/* Not using Overlapped structure, so use normal blocking on event */
APCContext = NULL;
@@ -476,35 +553,48 @@ WSPRecvFrom(SOCKET Handle,
if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
{
TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesRead);
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_SUCCESS, lpErrno, 0, lpNumberOfBytesRead);
}
if (lpCompletionRoutine == NULL)
IOSB = (PIO_STATUS_BLOCK)lpOverlapped;
if (lpCompletionRoutine)
{
/* Using Overlapped Structure, but no Completion Routine, so no need for APC */
APCContext = lpOverlapped;
APCFunction = NULL;
Event = lpOverlapped->hEvent;
}
else
{
/* Using Overlapped Structure and a Completion Routine, so use an APC */
APCFunction = &AfdAPC; // should be a private io completion function inside us
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
RecvInfo.AfdFlags |= AFD_SKIP_FIO;
APCFunction = &MsafdRecvAPC;
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext));
if (!APCContext)
{
ERR("Not enough memory for APC Context\n");
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesRead);
}
((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine;
((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped;
((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket;
RecvInfo.AfdFlags |= AFD_SKIP_FIO;
APCContext->lpCompletionRoutine = lpCompletionRoutine;
APCContext->lpOverlapped = lpOverlapped;
APCContext->lpSocket = Socket;
Event = NULL;
}
else
{
// .NET expects APCContext to be a pointer to lpOverlapped, because it adds some additional metadata past OVERLAPPED structure
// Do not use APC when there is no completion context
Event = lpOverlapped->hEvent;
APCFunction = NULL;
if (Event)
{
APCContext = NULL;
}
else
{
APCContext = (PMSAFD_RECV_APC_CONTEXT)lpOverlapped;
}
}
IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal;
RecvInfo.AfdFlags |= AFD_OVERLAPPED;
}
IOSB->Information = 0;
IOSB->Status = STATUS_PENDING;
/* Send IOCTL */
@@ -519,39 +609,47 @@ WSPRecvFrom(SOCKET Handle,
NULL,
0);
/* Wait for completion of not overlapped */
if (Status == STATUS_PENDING && lpOverlapped == NULL)
/* Wait for completion if not overlapped */
if (!lpOverlapped && Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE); // BUGBUG, shouldn wait infinitely for receive...
/* FIXME: Shouldn't wait infinitely for receive... */
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
NtClose( SockEvent );
/* Return the Flags */
if (ReceiveFlags)
{
switch (Status)
{
case STATUS_RECEIVE_EXPEDITED:
*ReceiveFlags = MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL_EXPEDITED:
*ReceiveFlags = MSG_PARTIAL | MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL:
*ReceiveFlags = MSG_PARTIAL;
break;
default:
*ReceiveFlags = 0;
break;
}
}
NtClose(SockEvent);
if (Status == STATUS_PENDING)
{
TRACE("Leaving (Pending)\n");
return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesRead);
*lpErrno = WSA_IO_PENDING;
return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
}
/* Return the Flags */
*ReceiveFlags = 0;
switch (Status)
{
case STATUS_RECEIVE_EXPEDITED:
*ReceiveFlags = MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL_EXPEDITED:
*ReceiveFlags = MSG_PARTIAL | MSG_OOB;
break;
case STATUS_RECEIVE_PARTIAL:
*ReceiveFlags = MSG_PARTIAL;
break;
}
NumberOfBytesRead = (DWORD)IOSB->Information;
/* Re-enable Async Event */
if (*ReceiveFlags & MSG_OOB)
if (ReceiveFlags && (*ReceiveFlags & MSG_OOB))
{
SockReenableAsyncSelectEvent(Socket, FD_OOB);
}
@@ -560,13 +658,8 @@ WSPRecvFrom(SOCKET Handle,
SockReenableAsyncSelectEvent(Socket, FD_READ);
}
if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine)
{
lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, *ReceiveFlags);
HeapFree(GlobalHeap, 0, (PVOID)APCContext);
}
return MsafdReturnWithErrno ( Status, lpErrno, IOSB->Information, lpNumberOfBytesRead );
TRACE("Leaving (%lx %ld)\n", Status, NumberOfBytesRead);
return MsafdReturnWithErrno(Status, lpErrno, NumberOfBytesRead, lpNumberOfBytesRead);
}
@@ -584,13 +677,16 @@ WSPSend(SOCKET Handle,
{
PIO_STATUS_BLOCK IOSB;
IO_STATUS_BLOCK DummyIOSB;
AFD_SEND_INFO SendInfo;
AFD_SEND_INFO SendInfo;
NTSTATUS Status;
PVOID APCContext;
PVOID APCFunction;
PMSAFD_SEND_APC_CONTEXT APCContext;
PIO_APC_ROUTINE APCFunction = NULL;
HANDLE Event = NULL;
HANDLE SockEvent;
PSOCKET_INFORMATION Socket;
DWORD NumberOfBytesSent;
TRACE("WSPSend(%x)\n", Handle);
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
@@ -607,13 +703,15 @@ WSPSend(SOCKET Handle,
return SOCKET_ERROR;
}
memset(&SendInfo, 0, sizeof(SendInfo));
Status = NtCreateEvent( &SockEvent, EVENT_ALL_ACCESS,
NULL, SynchronizationEvent, FALSE );
if( !NT_SUCCESS(Status) )
return -1;
TRACE("Called\n");
{
return SOCKET_ERROR;
}
/* Set up the Send Structure */
SendInfo.BufferArray = (PAFD_WSABUF)lpBuffers;
@@ -635,7 +733,7 @@ WSPSend(SOCKET Handle,
}
/* Verify if we should use APC */
if (lpOverlapped == NULL)
if (!lpOverlapped)
{
/* Not using Overlapped structure, so use normal blocking on event */
APCContext = NULL;
@@ -649,35 +747,58 @@ WSPSend(SOCKET Handle,
if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
{
TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesSent);
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_SUCCESS, lpErrno, 0, lpNumberOfBytesSent);
}
if (lpCompletionRoutine == NULL)
IOSB = (PIO_STATUS_BLOCK)lpOverlapped;
if (lpCompletionRoutine)
{
/* Using Overlapped Structure, but no Completion Routine, so no need for APC */
APCContext = lpOverlapped;
APCFunction = NULL;
Event = lpOverlapped->hEvent;
}
else
{
/* Using Overlapped Structure and a Completion Routine, so use an APC */
APCFunction = &AfdAPC; // should be a private io completion function inside us
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
SendInfo.AfdFlags |= AFD_SKIP_FIO;
APCFunction = &MsafdSendAPC;
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext));
if (!APCContext)
{
ERR("Not enough memory for APC Context\n");
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesSent);
}
((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine;
((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped;
((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket;
SendInfo.AfdFlags |= AFD_SKIP_FIO;
APCContext->lpCompletionRoutine = lpCompletionRoutine;
APCContext->lpOverlapped = lpOverlapped;
APCContext->lpSocket = Socket;
APCContext->lpRemoteAddress = NULL;
Event = NULL;
}
else
{
// .NET expects APCContext to be a pointer to lpOverlapped, because it adds some additional metadata past OVERLAPPED structure
// Do not use APC when there is no completion context
Event = lpOverlapped->hEvent;
APCFunction = NULL;
if (Event)
{
APCContext = NULL;
}
else
{
APCContext = (PMSAFD_SEND_APC_CONTEXT)lpOverlapped;
}
}
Event = lpOverlapped->hEvent;
// APCContext is unsupported when using event
if (Event)
{
if (lpCompletionRoutine)
HeapFree(GetProcessHeap(), 0, APCContext);
APCContext = NULL;
}
IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal;
SendInfo.AfdFlags |= AFD_OVERLAPPED;
}
IOSB->Information = 0;
IOSB->Status = STATUS_PENDING;
/* Send IOCTL */
@@ -692,33 +813,29 @@ WSPSend(SOCKET Handle,
NULL,
0);
/* Wait for completion of not overlapped */
if (Status == STATUS_PENDING && lpOverlapped == NULL)
/* Wait for completion if not overlapped */
if (!lpOverlapped && Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE); // BUGBUG, shouldn wait infinitely for send...
/* FIXME: Shouldn't wait infinitely for send... */
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
NtClose( SockEvent );
NtClose(SockEvent);
if (Status == STATUS_PENDING)
{
TRACE("Leaving (Pending)\n");
return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesSent);
return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
}
NumberOfBytesSent = (DWORD)IOSB->Information;
/* Re-enable Async Event */
SockReenableAsyncSelectEvent(Socket, FD_WRITE);
TRACE("Leaving (Success, %d)\n", IOSB->Information);
if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine)
{
lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, 0);
HeapFree(GlobalHeap, 0, (PVOID)APCContext);
}
return MsafdReturnWithErrno( Status, lpErrno, IOSB->Information, lpNumberOfBytesSent );
TRACE("Leaving (%lx %ld)\n", Status, NumberOfBytesSent);
return MsafdReturnWithErrno(Status, lpErrno, NumberOfBytesSent, lpNumberOfBytesSent);
}
int
@@ -739,14 +856,16 @@ WSPSendTo(SOCKET Handle,
IO_STATUS_BLOCK DummyIOSB;
AFD_SEND_INFO_UDP SendInfo;
NTSTATUS Status;
PVOID APCContext;
PVOID APCFunction;
PMSAFD_SEND_APC_CONTEXT APCContext;
PIO_APC_ROUTINE APCFunction = NULL;
HANDLE Event = NULL;
PTRANSPORT_ADDRESS RemoteAddress;
PSOCKADDR BindAddress = NULL;
INT BindAddressLength;
HANDLE SockEvent;
PSOCKET_INFORMATION Socket;
DWORD NumberOfBytesSent;
UCHAR Buffer[128];
TRACE("WSPSendTo(%x)\n", Handle);
/* Get the Socket Structure associate to this Socket */
Socket = GetSocketStructure(Handle);
@@ -780,6 +899,10 @@ WSPSendTo(SOCKET Handle,
/* Bind us First */
if (Socket->SharedData->State == SocketOpen)
{
INT BindAddressLength;
PSOCKADDR BindAddress;
INT BindError;
/* Get the Wildcard Address */
BindAddressLength = Socket->HelperData->MaxWSAddressLength;
BindAddress = HeapAlloc(GlobalHeap, 0, BindAddressLength);
@@ -793,35 +916,29 @@ WSPSendTo(SOCKET Handle,
BindAddress,
&BindAddressLength);
/* Bind it */
if (WSPBind(Handle, BindAddress, BindAddressLength, lpErrno) == SOCKET_ERROR)
BindError = WSPBind(Handle, BindAddress, BindAddressLength, lpErrno);
HeapFree(GlobalHeap, 0, BindAddress);
if (BindError == SOCKET_ERROR)
return SOCKET_ERROR;
}
RemoteAddress = HeapAlloc(GlobalHeap, 0, 0x6 + SocketAddressLength);
if (!RemoteAddress)
{
if (BindAddress != NULL)
{
HeapFree(GlobalHeap, 0, BindAddress);
}
return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
}
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
NULL, SynchronizationEvent, FALSE);
EVENT_ALL_ACCESS,
NULL, SynchronizationEvent, FALSE);
if (!NT_SUCCESS(Status))
{
HeapFree(GlobalHeap, 0, RemoteAddress);
if (BindAddress != NULL)
{
HeapFree(GlobalHeap, 0, BindAddress);
}
return SOCKET_ERROR;
}
/* Set up Address in TDI Format */
RemoteAddress = (PTRANSPORT_ADDRESS)Buffer;
if (SocketAddressLength > 128 - sizeof(TRANSPORT_ADDRESS)) {
*lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
/* Fill TDI address */
RemoteAddress->TAAddressCount = 1;
RemoteAddress->Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
RtlCopyMemory(&RemoteAddress->Address[0].AddressType, SocketAddress, SocketAddressLength);
@@ -834,7 +951,7 @@ WSPSendTo(SOCKET Handle,
SendInfo.TdiConnection.RemoteAddressLength = Socket->HelperData->MaxTDIAddressLength;
/* Verify if we should use APC */
if (lpOverlapped == NULL)
if (!lpOverlapped)
{
/* Not using Overlapped structure, so use normal blocking on event */
APCContext = NULL;
@@ -848,35 +965,51 @@ WSPSendTo(SOCKET Handle,
if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
{
TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesSent);
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_SUCCESS, lpErrno, 0, lpNumberOfBytesSent);
}
if (lpCompletionRoutine == NULL)
IOSB = (PIO_STATUS_BLOCK)lpOverlapped;
if (lpCompletionRoutine)
{
/* Using Overlapped Structure, but no Completion Routine, so no need for APC */
APCContext = lpOverlapped;
APCFunction = NULL;
Event = lpOverlapped->hEvent;
}
else
{
/* Using Overlapped Structure and a Completion Routine, so use an APC */
APCFunction = &AfdAPC; // should be a private io completion function inside us
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
SendInfo.AfdFlags |= AFD_SKIP_FIO;
APCFunction = &MsafdSendAPC;
APCContext = HeapAlloc(GlobalHeap, 0, sizeof(*APCContext));
if (!APCContext)
{
ERR("Not enough memory for APC Context\n");
NtClose(SockEvent);
return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesSent);
}
((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine;
((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped;
((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket;
SendInfo.AfdFlags |= AFD_SKIP_FIO;
APCContext->lpCompletionRoutine = lpCompletionRoutine;
APCContext->lpOverlapped = lpOverlapped;
APCContext->lpSocket = Socket;
APCContext->lpRemoteAddress = RemoteAddress;
Event = NULL;
}
else
{
// .NET expects APCContext to be a pointer to lpOverlapped, because it adds some additional metadata past OVERLAPPED structure
// Do not use APC when there is no completion context
Event = lpOverlapped->hEvent;
APCFunction = NULL;
if (Event)
{
APCContext = NULL;
}
else
{
APCContext = (PMSAFD_SEND_APC_CONTEXT)lpOverlapped;
}
}
IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal;
SendInfo.AfdFlags |= AFD_OVERLAPPED;
}
IOSB->Information = 0;
IOSB->Status = STATUS_PENDING;
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)Handle,
Event,
@@ -889,36 +1022,29 @@ WSPSendTo(SOCKET Handle,
NULL,
0);
/* Wait for completion of not overlapped */
if (Status == STATUS_PENDING && lpOverlapped == NULL)
/* Wait for completion if not overlapped */
if (!lpOverlapped && Status == STATUS_PENDING)
{
/* BUGBUG, shouldn't wait infinitely for send... */
WaitForSingleObject(SockEvent, INFINITE);
/* FIXME: Shouldn't wait infinitely for send... */
MsafdWaitForAlert(SockEvent);
Status = IOSB->Status;
}
NtClose(SockEvent);
HeapFree(GlobalHeap, 0, RemoteAddress);
if (BindAddress != NULL)
{
HeapFree(GlobalHeap, 0, BindAddress);
}
if (Status == STATUS_PENDING)
{
TRACE("Leaving (Pending)\n");
return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesSent);
return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
}
NumberOfBytesSent = (DWORD)IOSB->Information;
/* Re-enable Async Event */
SockReenableAsyncSelectEvent(Socket, FD_WRITE);
if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine)
{
lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, 0);
HeapFree(GlobalHeap, 0, (PVOID)APCContext);
}
return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesSent);
TRACE("Leaving (%lx %ld)\n", Status, NumberOfBytesSent);
return MsafdReturnWithErrno(Status, lpErrno, NumberOfBytesSent, lpNumberOfBytesSent);
}
INT

View File

@@ -133,12 +133,34 @@ typedef struct _ASYNC_DATA {
AFD_POLL_INFO AsyncSelectInfo;
} ASYNC_DATA, *PASYNC_DATA;
typedef struct _AFDAPCCONTEXT
typedef struct _MSAFD_INFO_APC_CONTEXT
{
LPWSAOVERLAPPED lpOverlapped;
LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine;
PAFD_INFO lpInfoData;
} MSAFD_INFO_APC_CONTEXT, *PMSAFD_INFO_APC_CONTEXT;
typedef struct _MSAFD_SEND_APC_CONTEXT
{
LPWSAOVERLAPPED lpOverlapped;
LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine;
PSOCKET_INFORMATION lpSocket;
} AFDAPCCONTEXT, *PAFDAPCCONTEXT;
PTRANSPORT_ADDRESS lpRemoteAddress;
} MSAFD_SEND_APC_CONTEXT, *PMSAFD_SEND_APC_CONTEXT;
typedef struct _MSAFD_RECV_APC_CONTEXT
{
LPWSAOVERLAPPED lpOverlapped;
LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine;
PSOCKET_INFORMATION lpSocket;
} MSAFD_RECV_APC_CONTEXT, *PMSAFD_RECV_APC_CONTEXT;
typedef struct _MSAFD_CONNECT_APC_CONTEXT
{
SOCKET lpSocket;
IO_STATUS_BLOCK IoStatusBlock;
} MSAFD_CONNECT_APC_CONTEXT, *PMSAFD_CONNECT_APC_CONTEXT;
_Must_inspect_result_
SOCKET
@@ -543,10 +565,11 @@ typedef VOID (*PASYNC_COMPLETION_ROUTINE)(PVOID Context, PIO_STATUS_BLOCK IoStat
FORCEINLINE
DWORD
MsafdReturnWithErrno(NTSTATUS Status,
LPINT Errno,
DWORD Received,
LPDWORD ReturnedBytes)
MsafdReturnWithErrno(
_In_ NTSTATUS Status,
_Out_opt_ LPINT Errno,
_In_ DWORD Received,
_Out_opt_ LPDWORD ReturnedBytes)
{
if (Errno)
{
@@ -559,8 +582,6 @@ MsafdReturnWithErrno(NTSTATUS Status,
}
else
{
DbgPrint("%s: Received invalid lpErrno pointer!\n", __FUNCTION__);
if (ReturnedBytes)
*ReturnedBytes = (Status == STATUS_SUCCESS) ? Received : 0;
@@ -568,4 +589,8 @@ MsafdReturnWithErrno(NTSTATUS Status,
}
}
VOID
MsafdWaitForAlert(
_In_ HANDLE hObject);
#endif /* __MSAFD_H */

View File

@@ -379,6 +379,10 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
Status = QueueUserModeIrp(FCB, Irp, FUNCTION_SEND);
if (Status == STATUS_PENDING)
{
if (SendReq->BufferCount > 1)
{
AFD_DbgPrint(MIN_TRACE,("WARN: More than one buffer %ld\n", SendReq->BufferCount));
}
Status = TdiSendDatagram(&FCB->SendIrp.InFlightRequest,
FCB->AddressFile.Object,
SendReq->BufferArray[0].buf,
@@ -390,7 +394,7 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
{
NT_VERIFY(RemoveHeadList(&FCB->PendingIrpList[FUNCTION_SEND]) == &Irp->Tail.Overlay.ListEntry);
Irp->IoStatus.Status = Status;
Irp->IoStatus.Information = 0;
Irp->IoStatus.Information = ((Status == STATUS_SUCCESS) ? SendReq->BufferArray[0].len : 0);
(void)IoSetCancelRoutine(Irp, NULL);
UnlockBuffers(SendReq->BufferArray, SendReq->BufferCount, FALSE);
UnlockRequest(Irp, IoGetCurrentIrpStackLocation(Irp));
@@ -645,6 +649,10 @@ AfdPacketSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
Status = QueueUserModeIrp(FCB, Irp, FUNCTION_SEND);
if (Status == STATUS_PENDING)
{
if (SendReq->BufferCount > 1)
{
AFD_DbgPrint(MIN_TRACE,("WARN: More than one buffer %ld\n", SendReq->BufferCount));
}
Status = TdiSendDatagram(&FCB->SendIrp.InFlightRequest,
FCB->AddressFile.Object,
SendReq->BufferArray[0].buf,
@@ -656,7 +664,7 @@ AfdPacketSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
{
NT_VERIFY(RemoveHeadList(&FCB->PendingIrpList[FUNCTION_SEND]) == &Irp->Tail.Overlay.ListEntry);
Irp->IoStatus.Status = Status;
Irp->IoStatus.Information = 0;
Irp->IoStatus.Information = ((Status == STATUS_SUCCESS) ? SendReq->BufferArray[0].len : 0);
(void)IoSetCancelRoutine(Irp, NULL);
UnlockBuffers(SendReq->BufferArray, SendReq->BufferCount, FALSE);
UnlockRequest(Irp, IoGetCurrentIrpStackLocation(Irp));

View File

@@ -989,11 +989,9 @@ NTSTATUS TdiSend(
Flags,
BufferLength);
TdiCall(*Irp, DeviceObject, NULL, NULL);
/* Does not block... The MDL is deleted in the receive completion
routine. */
return STATUS_PENDING;
return TdiCall(*Irp, DeviceObject, NULL, NULL);
}
NTSTATUS TdiReceive(
@@ -1256,11 +1254,9 @@ NTSTATUS TdiSendDatagram(
BufferLength,
Addr);
TdiCall(*Irp, DeviceObject, NULL, NULL);
/* Does not block... The MDL is deleted in the send completion
routine. */
return STATUS_PENDING;
return TdiCall(*Irp, DeviceObject, NULL, NULL);
}
NTSTATUS TdiDisconnect(