Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llamafile/flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ int FLAG_flash_attn = false;
int FLAG_gpu = 0;
int FLAG_http_ibuf_size = 5 * 1024 * 1024; // bytes (5 MiB)
int FLAG_http_obuf_size = 1024 * 1024; // bytes (1 MiB)
int FLAG_http_write_timeout = 60000; // milliseconds; passed to poll() while waiting for a socket to become writable
int FLAG_keepalive = 5;
int FLAG_main_gpu = 0;
int FLAG_n_gpu_layers = -1;
Expand Down Expand Up @@ -346,6 +347,13 @@ void llamafile_get_flags(int argc, char **argv) {
continue;
}

// --http-write-timeout MS: how long, in milliseconds, a blocked socket
// write waits for the client before the connection is closed.
if (!strcmp(flag, "--http-write-timeout")) {
// The flag requires a value; `missing` presumably reports the error and
// does not return -- TODO confirm against its definition.
if (i == argc)
missing("--http-write-timeout");
// NOTE(review): atoi() performs no validation. Non-numeric input yields 0
// and negative values are stored as-is; the value is later handed
// unchanged to poll() as its timeout argument.
FLAG_http_write_timeout = atoi(argv[i++]);
continue;
}

//////////////////////////////////////////////////////////////////////
// sampling flags

Expand Down
1 change: 1 addition & 0 deletions llamafile/llamafile.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ extern int FLAG_gpu;
extern int FLAG_gpu;
extern int FLAG_http_ibuf_size;
extern int FLAG_http_obuf_size;
extern int FLAG_http_write_timeout;
extern int FLAG_keepalive;
extern int FLAG_main_gpu;
extern int FLAG_n_gpu_layers;
Expand Down
43 changes: 34 additions & 9 deletions llamafile/server/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <poll.h>
#include <string.h>
#include <string>
#include <sys/stat.h>
Expand Down Expand Up @@ -520,14 +521,38 @@ Client::send_response_finish()
//
// unlike send() this won't fail if binary content is detected.
bool
Client::send_binary(const void* p, size_t n)
{
ssize_t sent;
if ((sent = write(fd_, p, n)) != n) {
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
SLOG("write failed %m");
close_connection_ = true;
return false;
Client::send_binary(const void* p, size_t n) {
    // Pushes all n bytes at p to the client socket, retrying after short
    // writes and signal interruptions. When the socket reports that it
    // would block, waits up to FLAG_http_write_timeout milliseconds for it
    // to become writable again. Returns true once everything was written;
    // on any failure marks the connection for closing and returns false.
    const char* cursor = (const char*)p;
    size_t remaining = n;
    while (remaining) {
        ssize_t rc = write(fd_, cursor, remaining);
        if (rc != -1) {
            // partial or complete write: advance past the accepted bytes
            cursor += rc;
            remaining -= rc;
            continue;
        }

        if (errno == EAGAIN || errno == EWOULDBLOCK) {
            // send buffer is full; block until there is room or we give up
            struct pollfd waiter = { .fd = fd_, .events = POLLOUT };
            int ready = poll(&waiter, 1, FLAG_http_write_timeout);
            if (ready > 0)
                continue;
            if (ready == 0) {
                SLOG("write timed out");
                close_connection_ = true;
                return false;
            }
            if (errno == EINTR)
                continue;
            SLOG("poll failed %m");
            close_connection_ = true;
            return false;
        }

        // interrupted by a signal before anything was written: just retry
        if (errno == EINTR)
            continue;

        // a peer reset is routine, so only other errors are logged
        if (errno != ECONNRESET)
            SLOG("write failed %m");
        close_connection_ = true;
        return false;
    }
    return true;
}
Expand Down Expand Up @@ -775,7 +800,7 @@ Client::dispatcher()
should_send_error_if_canceled_ = false;
if (!send(std::string_view(obuf_.p, p - obuf_.p)))
return false;
char buf[512];
char buf[16384];
size_t i, chunk;
for (i = 0; i < size; i += chunk) {
chunk = size - i;
Expand Down
5 changes: 5 additions & 0 deletions llamafile/server/main.1
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ supported by the host operating system. The default keepalive is 5.
Size of HTTP output buffer, in bytes. Default is 1048576.
.It Fl Fl http-ibuf-size Ar N
Size of HTTP input buffer, in bytes. Default is 5242880.
.It Fl Fl http-write-timeout Ar MS
Socket write timeout in milliseconds. When sending data to a client, if
the socket buffer is full and the client is not reading, the server will
wait up to this many milliseconds for the socket to become writable before
closing the connection. Default is 60000 (60 seconds).
.It Fl Fl chat-template Ar NAME
Specifies or overrides chat template for model.
.Pp
Expand Down
7 changes: 7 additions & 0 deletions llamafile/server/main.1.asc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,13 @@
--http-ibuf-size N
Size of HTTP input buffer, in bytes. Default is 5242880.

--http-write-timeout MS
Socket write timeout in milliseconds. When sending data to a
client, if the socket buffer is full and the client is not
reading, the server will wait up to this many milliseconds for
the socket to become writable before closing the connection.
Default is 60000 (60 seconds).

--chat-template NAME
Specifies or overrides chat template for model.

Expand Down
7 changes: 0 additions & 7 deletions llamafile/server/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,6 @@ Worker::begin()
tokens = tokenbucket_acquire(client_.client_ip_);
server_->lock();
dll_remove(&server_->idle_workers, &elem_);
if (dll_is_empty(server_->idle_workers)) {
Dll* slowbro;
if ((slowbro = dll_last(server_->active_workers))) {
SLOG("all threads active! dropping oldest client");
WORKER(slowbro)->kill();
}
}
working_ = true;
if (tokens > FLAG_token_burst) {
dll_make_last(&server_->active_workers, &elem_);
Expand Down
49 changes: 48 additions & 1 deletion llamafile/server/writev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@
// limitations under the License.

#include "llamafile/server/log.h"
#include "llamafile/llamafile.h"
#include "utils.h"
#include <cerrno>
#include <poll.h>
#include <string_view>
#include <vector>

namespace lf {
namespace server {

ssize_t
safe_writev(int fd, const iovec* iov, int iovcnt)
{
// Security check for binary content in headers
for (int i = 0; i < iovcnt; ++i) {
bool has_binary = false;
size_t n = iov[i].iov_len;
Expand All @@ -39,7 +43,50 @@ safe_writev(int fd, const iovec* iov, int iovcnt)
return -1;
}
}
return writev(fd, iov, iovcnt);

// Total number of bytes the kernel has accepted so far; this is what the
// caller receives on success.
ssize_t total = 0;
// writev() may accept only part of the request, so work on a mutable copy
// of the iovec array whose entries can be trimmed as data is consumed.
std::vector<iovec> copy(iov, iov + iovcnt);
int i = 0; // index of the first iovec that still has unsent data

while (i < iovcnt) {
    ssize_t sent = writev(fd, copy.data() + i, iovcnt - i);
    if (sent == -1) {
        // interrupted before any data was transferred: retry the call
        if (errno == EINTR)
            continue;
        if (errno == EAGAIN || errno == EWOULDBLOCK) {
            // Socket buffer is full. Wait until it drains, but no longer
            // than the configured write timeout (milliseconds).
            struct pollfd pfd = { .fd = fd, .events = POLLOUT };
            int rc = poll(&pfd, 1, FLAG_http_write_timeout);
            if (rc == 0) {
                errno = ETIMEDOUT;
                return -1;
            }
            if (rc == -1) {
                if (errno == EINTR)
                    continue;
                return -1;
            }
            continue;
        }
        return -1;
    }

    total += sent;
    size_t got = sent;

    // Step over every iovec that is now fully consumed. The comparison is
    // deliberately not guarded by `got > 0`: zero-length iovecs must be
    // skipped too, otherwise a trailing empty iovec would make writev()
    // return 0 on every iteration and this loop would never terminate.
    while (i < iovcnt && got >= copy[i].iov_len) {
        got -= copy[i].iov_len;
        ++i;
    }

    // Short write that ended inside an iovec: trim the part already sent.
    if (i < iovcnt && got > 0) {
        copy[i].iov_base = (char*)copy[i].iov_base + got;
        copy[i].iov_len -= got;
    }
}
return total;
}

} // namespace server
Expand Down