mirror of
				https://github.com/eledio-devices/thirdparty-AsyncTCPSock.git
				synced 2025-10-31 08:42:38 +01:00 
			
		
		
		
	Protect write buffer queue with a per-client mutex
This prevents a race condicion of the ::add() method modifying the queue and being preempted by the high-priority asyncTcpSock task that in turn flushes the same queue to the socket, therefore modifying the queue.
This commit is contained in:
		| @@ -271,6 +271,7 @@ AsyncClient::AsyncClient(int sockfd) | |||||||
| , _writeSpaceRemaining(TCP_SND_BUF) | , _writeSpaceRemaining(TCP_SND_BUF) | ||||||
| , _conn_state(0) | , _conn_state(0) | ||||||
| { | { | ||||||
|  |     _write_mutex = xSemaphoreCreateMutex(); | ||||||
|     if (sockfd != -1) { |     if (sockfd != -1) { | ||||||
|         int r = fcntl( sockfd, F_SETFL, fcntl( sockfd, F_GETFL, 0 ) | O_NONBLOCK ); |         int r = fcntl( sockfd, F_SETFL, fcntl( sockfd, F_GETFL, 0 ) | O_NONBLOCK ); | ||||||
|  |  | ||||||
| @@ -564,25 +565,23 @@ bool AsyncClient::_sockIsWriteable(void) | |||||||
|     bool activity = false; |     bool activity = false; | ||||||
|     bool hasErr = false; |     bool hasErr = false; | ||||||
|  |  | ||||||
|     //Serial.print("AsyncClient::_sockIsWriteable: "); Serial.println(_socket); |     int sent_errno = 0; | ||||||
|  |     size_t sent_cb_length = 0; | ||||||
|  |     uint32_t sent_cb_delay = 0; | ||||||
|  |  | ||||||
|     // Socket is now writeable. What should we do? |     // Socket is now writeable. What should we do? | ||||||
|     switch (_conn_state) { |     switch (_conn_state) { | ||||||
|     case 2: |     case 2: | ||||||
|     case 3: |     case 3: | ||||||
|         //Serial.println("\tconnect end"); |  | ||||||
|         // Socket has finished connecting. What happened? |         // Socket has finished connecting. What happened? | ||||||
|         len = (socklen_t)sizeof(int); |         len = (socklen_t)sizeof(int); | ||||||
|         res = getsockopt(_socket, SOL_SOCKET, SO_ERROR, &sockerr, &len); |         res = getsockopt(_socket, SOL_SOCKET, SO_ERROR, &sockerr, &len); | ||||||
|         if (res < 0) { |         if (res < 0) { | ||||||
|             //Serial.printf("\terrno=%d\r\n", errno); |  | ||||||
|             _error(errno); |             _error(errno); | ||||||
|         } else if (sockerr != 0) { |         } else if (sockerr != 0) { | ||||||
|             //Serial.printf("\tsockerr=%d\r\n", errno); |  | ||||||
|             _error(sockerr); |             _error(sockerr); | ||||||
|         } else { |         } else { | ||||||
|             // Socket is now fully connected |             // Socket is now fully connected | ||||||
|             //Serial.println("SUCCESS"); |  | ||||||
|             _conn_state = 4; |             _conn_state = 4; | ||||||
|             activity = true; |             activity = true; | ||||||
|             _rx_last_packet = millis(); |             _rx_last_packet = millis(); | ||||||
| @@ -596,29 +595,23 @@ bool AsyncClient::_sockIsWriteable(void) | |||||||
|     case 4: |     case 4: | ||||||
|     default: |     default: | ||||||
|         // Socket can accept some new data... |         // Socket can accept some new data... | ||||||
|         //Serial.printf("\tbefore: remaining %d\r\n", _writeSpaceRemaining); |         xSemaphoreTake(_write_mutex, (TickType_t)portMAX_DELAY); | ||||||
|         if (_writeQueue.size() > 0) { |         if (_writeQueue.size() > 0) { | ||||||
|             //Serial.printf("\tbuffers remaining: %d\r\n", _writeQueue.size()); |  | ||||||
|             if (_writeQueue.front().written < _writeQueue.front().length) { |             if (_writeQueue.front().written < _writeQueue.front().length) { | ||||||
|                 uint8_t * p = _writeQueue.front().data + _writeQueue.front().written; |                 uint8_t * p = _writeQueue.front().data + _writeQueue.front().written; | ||||||
|                 size_t n = _writeQueue.front().length - _writeQueue.front().written; |                 size_t n = _writeQueue.front().length - _writeQueue.front().written; | ||||||
|                 //Serial.printf("\tlwip_write(%p, %d) ... ", p, n); |  | ||||||
|                 errno = 0; ssize_t r = lwip_write(_socket, p, n); |                 errno = 0; ssize_t r = lwip_write(_socket, p, n); | ||||||
|                 //Serial.printf("r=%d errno=%d\r\n", r, errno); |  | ||||||
|  |  | ||||||
|                 if (r >= 0) { |                 if (r >= 0) { | ||||||
|                     // Written some data into the socket |                     // Written some data into the socket | ||||||
|                     _writeQueue.front().written += r; |                     _writeQueue.front().written += r; | ||||||
|                     _writeSpaceRemaining += r; |                     _writeSpaceRemaining += r; | ||||||
|                     activity = true; |                     activity = true; | ||||||
|                     //Serial.printf("\tduring: remaining %d\r\n", _writeSpaceRemaining); |  | ||||||
|                 } else if (errno == EAGAIN || errno == EWOULDBLOCK) { |                 } else if (errno == EAGAIN || errno == EWOULDBLOCK) { | ||||||
|                     // Socket is full again |                     // Socket is full, could not write anything | ||||||
|                     //Serial.println("\tEAGAIN"); |  | ||||||
|                     break;  // NOTE: breaks from switch() |  | ||||||
|                 } else { |                 } else { | ||||||
|                     hasErr = true; |                     hasErr = true; | ||||||
|                     _error(errno); |                     sent_errno = errno; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -626,14 +619,18 @@ bool AsyncClient::_sockIsWriteable(void) | |||||||
|                 // Buffer has been fully written to the socket |                 // Buffer has been fully written to the socket | ||||||
|                 _rx_last_packet = millis(); |                 _rx_last_packet = millis(); | ||||||
|                 if (_writeQueue.front().owned) ::free(_writeQueue.front().data); |                 if (_writeQueue.front().owned) ::free(_writeQueue.front().data); | ||||||
|                 if (_sent_cb) { |                 sent_cb_length = _writeQueue.front().length; | ||||||
|                     _sent_cb(_sent_cb_arg, this, _writeQueue.front().length, (millis() - _writeQueue.front().queued_at)); |                 uint32_t sent_cb_delay = millis() - _writeQueue.front().queued_at; | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 _writeQueue.pop_front(); |                 _writeQueue.pop_front(); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         //Serial.printf("\tafter: remaining %d\r\n", _writeSpaceRemaining); |         xSemaphoreGive(_write_mutex); | ||||||
|  |  | ||||||
|  |         if (hasErr) { | ||||||
|  |             _error(sent_errno); | ||||||
|  |         } else if (sent_cb_length > 0 && _sent_cb) { | ||||||
|  |             _sent_cb(_sent_cb_arg, this, sent_cb_length, sent_cb_delay); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         break; |         break; | ||||||
|     } |     } | ||||||
| @@ -643,12 +640,8 @@ bool AsyncClient::_sockIsWriteable(void) | |||||||
|  |  | ||||||
| void AsyncClient::_sockIsReadable(void) | void AsyncClient::_sockIsReadable(void) | ||||||
| { | { | ||||||
|     //Serial.print("AsyncClient::_sockIsReadable: "); Serial.println(_socket); |  | ||||||
|  |  | ||||||
|     _rx_last_packet = millis(); |     _rx_last_packet = millis(); | ||||||
|     //Serial.print("\tlwip_read ... "); |  | ||||||
|     errno = 0; ssize_t r = lwip_read(_socket, _readBuffer, MAX_PAYLOAD_SIZE); |     errno = 0; ssize_t r = lwip_read(_socket, _readBuffer, MAX_PAYLOAD_SIZE); | ||||||
|     //Serial.printf("r=%d errno=%d\r\n", r, errno); |  | ||||||
|     if (r > 0) { |     if (r > 0) { | ||||||
|         if(_recv_cb) { |         if(_recv_cb) { | ||||||
|             _recv_cb(_recv_cb_arg, this, _readBuffer, r); |             _recv_cb(_recv_cb_arg, this, _readBuffer, r); | ||||||
| @@ -658,7 +651,7 @@ void AsyncClient::_sockIsReadable(void) | |||||||
|         _close(); |         _close(); | ||||||
|     } else if (r < 0) { |     } else if (r < 0) { | ||||||
|         if (errno == EAGAIN || errno == EWOULDBLOCK) { |         if (errno == EAGAIN || errno == EWOULDBLOCK) { | ||||||
|             //Serial.println("\tEAGAIN"); |             // Do nothing, will try later | ||||||
|         } else { |         } else { | ||||||
|             _error(errno); |             _error(errno); | ||||||
|         } |         } | ||||||
| @@ -671,22 +664,22 @@ void AsyncClient::_sockPoll(void) | |||||||
|  |  | ||||||
|     uint32_t now = millis(); |     uint32_t now = millis(); | ||||||
|  |  | ||||||
|     //Serial.print("AsyncClient::_sockPoll: "); Serial.println(_socket); |  | ||||||
|  |  | ||||||
|     // ACK Timeout - simulated by write queue staleness |     // ACK Timeout - simulated by write queue staleness | ||||||
|     if (_writeQueue.size() > 0 && !_ack_timeout_signaled && _ack_timeout && (now - _writeQueue.front().queued_at) >= _ack_timeout) { |     xSemaphoreTake(_write_mutex, (TickType_t)portMAX_DELAY); | ||||||
|  |     uint32_t sent_delay = now - _writeQueue.front().queued_at; | ||||||
|  |     if (_writeQueue.size() > 0 && !_ack_timeout_signaled && _ack_timeout && sent_delay >= _ack_timeout) { | ||||||
|         _ack_timeout_signaled = true; |         _ack_timeout_signaled = true; | ||||||
|         //log_w("ack timeout %d", pcb->state); |         //log_w("ack timeout %d", pcb->state); | ||||||
|         //Serial.println("\tACK TIMEOUT"); |         xSemaphoreGive(_write_mutex); | ||||||
|         if(_timeout_cb) |         if(_timeout_cb) | ||||||
|             _timeout_cb(_timeout_cb_arg, this, (now - _writeQueue.front().queued_at)); |             _timeout_cb(_timeout_cb_arg, this, sent_delay); | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |     xSemaphoreGive(_write_mutex); | ||||||
|  |  | ||||||
|     // RX Timeout |     // RX Timeout | ||||||
|     if (_rx_since_timeout && (now - _rx_last_packet) >= (_rx_since_timeout * 1000)) { |     if (_rx_since_timeout && (now - _rx_last_packet) >= (_rx_since_timeout * 1000)) { | ||||||
|         //log_w("rx timeout %d", pcb->state); |         //log_w("rx timeout %d", pcb->state); | ||||||
|         //Serial.println("\tRX TIMEOUT"); |  | ||||||
|         _close(); |         _close(); | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
| @@ -753,9 +746,12 @@ size_t AsyncClient::add(const char* data, size_t size, uint8_t apiflags) | |||||||
|     n_entry.written = 0; |     n_entry.written = 0; | ||||||
|     n_entry.queued_at = millis(); |     n_entry.queued_at = millis(); | ||||||
|  |  | ||||||
|  |     xSemaphoreTake(_write_mutex, (TickType_t)portMAX_DELAY); | ||||||
|     _writeQueue.push_back(n_entry); |     _writeQueue.push_back(n_entry); | ||||||
|     _writeSpaceRemaining -= will_send; |     _writeSpaceRemaining -= will_send; | ||||||
|     _ack_timeout_signaled = false; |     _ack_timeout_signaled = false; | ||||||
|  |     xSemaphoreGive(_write_mutex); | ||||||
|  |  | ||||||
|     return will_send; |     return will_send; | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -180,6 +180,7 @@ class AsyncClient : public AsyncSocketBase | |||||||
|     } queued_writebuf; |     } queued_writebuf; | ||||||
|  |  | ||||||
|     // Queue of buffers to write to socket |     // Queue of buffers to write to socket | ||||||
|  |     SemaphoreHandle_t _write_mutex; | ||||||
|     std::deque<queued_writebuf> _writeQueue; |     std::deque<queued_writebuf> _writeQueue; | ||||||
|     bool _ack_timeout_signaled = false; |     bool _ack_timeout_signaled = false; | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user