263 lines
9.9 KiB
C++
263 lines
9.9 KiB
C++
#include "message.h"
|
|
#include <stdexcept>
|
|
#include <cstring>
|
|
#include <iostream>
|
|
#include <iomanip>
|
|
|
|
namespace scar {
|
|
|
|
/// Protocol revision written into the `version` field of every header that the
/// Message constructor builds.
constexpr uint8_t PROTOCOL_VERSION{1};
|
|
|
|
/// Creates a message of the given type with no payload.
///
/// The header is filled in so that it describes a header-only message:
/// `length` equals the size of the header, `version` is PROTOCOL_VERSION and
/// `reserved` is zero.
///
/// @param type Message kind stored in the header.
Message::Message(MessageType type) {
    // Identification fields.
    header_.type = type;
    header_.version = PROTOCOL_VERSION;
    header_.reserved = 0;

    // No payload yet, so the message is exactly one header long.
    header_.length = sizeof(MessageHeader);
}
|
|
|
|
std::vector<uint8_t> Message::serialize() const {
|
|
std::vector<uint8_t> buffer;
|
|
buffer.resize(header_.length);
|
|
|
|
// Copy header
|
|
std::memcpy(buffer.data(), &header_, sizeof(MessageHeader));
|
|
|
|
// Copy payload
|
|
if (!payload_.empty()) {
|
|
std::memcpy(buffer.data() + sizeof(MessageHeader), payload_.data(), payload_.size());
|
|
}
|
|
|
|
return buffer;
|
|
}
|
|
|
|
void Message::setPayload(const std::vector<uint8_t>& data) {
|
|
payload_ = data;
|
|
header_.length = sizeof(MessageHeader) + payload_.size();
|
|
}
|
|
|
|
std::unique_ptr<Message> Message::deserialize(const std::vector<uint8_t>& data) {
|
|
if (data.size() < sizeof(MessageHeader)) {
|
|
throw std::runtime_error("Invalid message: too short");
|
|
}
|
|
|
|
MessageHeader header;
|
|
std::memcpy(&header, data.data(), sizeof(MessageHeader));
|
|
|
|
if (header.length != data.size()) {
|
|
throw std::runtime_error("Invalid message: length mismatch");
|
|
}
|
|
|
|
std::vector<uint8_t> payload(data.begin() + sizeof(MessageHeader), data.end());
|
|
|
|
// Dispatch to specific message types
|
|
switch (header.type) {
|
|
case MessageType::LOGIN_REQUEST:
|
|
return LoginRequest::deserialize(payload);
|
|
case MessageType::LOGIN_RESPONSE:
|
|
return LoginResponse::deserialize(payload);
|
|
case MessageType::TEXT_MESSAGE:
|
|
return TextMessage::deserialize(payload);
|
|
default:
|
|
throw std::runtime_error("Unknown message type");
|
|
}
|
|
}
|
|
|
|
// LoginRequest implementation
|
|
/// Creates a login request holding the given credentials.
///
/// Writes a trace line to stdout containing the username and the password
/// length (never the password itself).
///
/// @param username Account name, copied into the message.
/// @param password Password, copied into the message.
LoginRequest::LoginRequest(const std::string& username, const std::string& password)
    : Message(MessageType::LOGIN_REQUEST),
      username_(username),
      password_(password) {
    std::ostream& trace = std::cout;
    trace << "LoginRequest constructor - Username: '" << username;
    trace << "', Password length: " << password.length();
    trace << std::endl;
}
|
|
|
|
std::vector<uint8_t> LoginRequest::serialize() const {
|
|
std::cout << "LoginRequest::serialize - Username: '" << username_ << "', Password length: " << password_.length() << std::endl;
|
|
|
|
std::vector<uint8_t> payload;
|
|
|
|
// Username length + username
|
|
uint16_t username_len = username_.size();
|
|
std::cout << " Serializing username_len: " << username_len << std::endl;
|
|
payload.insert(payload.end(), reinterpret_cast<const uint8_t*>(&username_len),
|
|
reinterpret_cast<const uint8_t*>(&username_len) + sizeof(username_len));
|
|
payload.insert(payload.end(), username_.begin(), username_.end());
|
|
|
|
// Password length + password
|
|
uint16_t password_len = password_.size();
|
|
std::cout << " Serializing password_len: " << password_len << std::endl;
|
|
payload.insert(payload.end(), reinterpret_cast<const uint8_t*>(&password_len),
|
|
reinterpret_cast<const uint8_t*>(&password_len) + sizeof(password_len));
|
|
payload.insert(payload.end(), password_.begin(), password_.end());
|
|
|
|
std::cout << " Total payload size: " << payload.size() << std::endl;
|
|
|
|
const_cast<LoginRequest*>(this)->setPayload(payload);
|
|
return Message::serialize();
|
|
}
|
|
|
|
std::unique_ptr<LoginRequest> LoginRequest::deserialize(const std::vector<uint8_t>& payload) {
|
|
std::cout << "LoginRequest::deserialize - Payload size: " << payload.size() << std::endl;
|
|
|
|
// Debug: print raw bytes
|
|
std::cout << " Raw payload bytes: ";
|
|
for (size_t i = 0; i < std::min(payload.size(), size_t(20)); ++i) {
|
|
std::cout << std::hex << std::setw(2) << std::setfill('0') << (int)payload[i] << " ";
|
|
}
|
|
std::cout << std::dec << std::endl;
|
|
|
|
size_t offset = 0;
|
|
|
|
// Read username
|
|
if (offset + sizeof(uint16_t) > payload.size()) {
|
|
throw std::runtime_error("Invalid LoginRequest: truncated username length");
|
|
}
|
|
uint16_t username_len;
|
|
std::memcpy(&username_len, payload.data() + offset, sizeof(username_len));
|
|
offset += sizeof(username_len);
|
|
|
|
std::cout << " Username length (raw bytes): " << (int)payload[0] << " " << (int)payload[1] << std::endl;
|
|
std::cout << " Username length (uint16_t): " << username_len << std::endl;
|
|
|
|
if (offset + username_len > payload.size()) {
|
|
throw std::runtime_error("Invalid LoginRequest: truncated username");
|
|
}
|
|
std::string username(payload.begin() + offset, payload.begin() + offset + username_len);
|
|
offset += username_len;
|
|
|
|
std::cout << " Username: '" << username << "'" << std::endl;
|
|
|
|
// Read password
|
|
if (offset + sizeof(uint16_t) > payload.size()) {
|
|
throw std::runtime_error("Invalid LoginRequest: truncated password length");
|
|
}
|
|
uint16_t password_len;
|
|
std::memcpy(&password_len, payload.data() + offset, sizeof(password_len));
|
|
offset += sizeof(password_len);
|
|
|
|
std::cout << " Password length: " << password_len << std::endl;
|
|
|
|
if (offset + password_len > payload.size()) {
|
|
throw std::runtime_error("Invalid LoginRequest: truncated password");
|
|
}
|
|
std::string password(payload.begin() + offset, payload.begin() + offset + password_len);
|
|
|
|
std::cout << " Password: (hidden, length=" << password.length() << ")" << std::endl;
|
|
|
|
return std::make_unique<LoginRequest>(username, password);
|
|
}
|
|
|
|
// LoginResponse implementation
|
|
/// Creates a login response.
///
/// @param success Whether the login succeeded.
/// @param token   Token string carried in the response, copied into the message.
/// @param error   Error code carried in the response.
LoginResponse::LoginResponse(bool success, const std::string& token, ErrorCode error)
    : Message(MessageType::LOGIN_RESPONSE),
      success_(success),
      token_(token),
      error_(error) {
}
|
|
|
|
std::vector<uint8_t> LoginResponse::serialize() const {
|
|
std::vector<uint8_t> payload;
|
|
|
|
// Success flag
|
|
payload.push_back(success_ ? 1 : 0);
|
|
|
|
// Error code
|
|
uint16_t error_code = static_cast<uint16_t>(error_);
|
|
payload.insert(payload.end(), reinterpret_cast<const uint8_t*>(&error_code),
|
|
reinterpret_cast<const uint8_t*>(&error_code) + sizeof(error_code));
|
|
|
|
// Token length + token
|
|
uint16_t token_len = token_.size();
|
|
payload.insert(payload.end(), reinterpret_cast<const uint8_t*>(&token_len),
|
|
reinterpret_cast<const uint8_t*>(&token_len) + sizeof(token_len));
|
|
payload.insert(payload.end(), token_.begin(), token_.end());
|
|
|
|
const_cast<LoginResponse*>(this)->setPayload(payload);
|
|
return Message::serialize();
|
|
}
|
|
|
|
std::unique_ptr<LoginResponse> LoginResponse::deserialize(const std::vector<uint8_t>& payload) {
|
|
size_t offset = 0;
|
|
|
|
// Read success flag
|
|
if (offset >= payload.size()) {
|
|
throw std::runtime_error("Invalid LoginResponse: missing success flag");
|
|
}
|
|
bool success = payload[offset++] != 0;
|
|
|
|
// Read error code
|
|
if (offset + sizeof(uint16_t) > payload.size()) {
|
|
throw std::runtime_error("Invalid LoginResponse: truncated error code");
|
|
}
|
|
uint16_t error_code;
|
|
std::memcpy(&error_code, payload.data() + offset, sizeof(error_code));
|
|
offset += sizeof(error_code);
|
|
ErrorCode error = static_cast<ErrorCode>(error_code);
|
|
|
|
// Read token
|
|
if (offset + sizeof(uint16_t) > payload.size()) {
|
|
throw std::runtime_error("Invalid LoginResponse: truncated token length");
|
|
}
|
|
uint16_t token_len;
|
|
std::memcpy(&token_len, payload.data() + offset, sizeof(token_len));
|
|
offset += sizeof(token_len);
|
|
|
|
if (offset + token_len > payload.size()) {
|
|
throw std::runtime_error("Invalid LoginResponse: truncated token");
|
|
}
|
|
std::string token(payload.begin() + offset, payload.begin() + offset + token_len);
|
|
|
|
return std::make_unique<LoginResponse>(success, token, error);
|
|
}
|
|
|
|
// TextMessage implementation
|
|
/// Creates a text message.
///
/// @param sender  Sender name, copied into the message.
/// @param content Message text, copied into the message.
TextMessage::TextMessage(const std::string& sender, const std::string& content)
    : Message(MessageType::TEXT_MESSAGE),
      sender_(sender),
      content_(content) {
}
|
|
|
|
std::vector<uint8_t> TextMessage::serialize() const {
|
|
std::vector<uint8_t> payload;
|
|
|
|
// Sender length + sender
|
|
uint16_t sender_len = sender_.size();
|
|
payload.insert(payload.end(), reinterpret_cast<const uint8_t*>(&sender_len),
|
|
reinterpret_cast<const uint8_t*>(&sender_len) + sizeof(sender_len));
|
|
payload.insert(payload.end(), sender_.begin(), sender_.end());
|
|
|
|
// Content length + content
|
|
uint16_t content_len = content_.size();
|
|
payload.insert(payload.end(), reinterpret_cast<const uint8_t*>(&content_len),
|
|
reinterpret_cast<const uint8_t*>(&content_len) + sizeof(content_len));
|
|
payload.insert(payload.end(), content_.begin(), content_.end());
|
|
|
|
const_cast<TextMessage*>(this)->setPayload(payload);
|
|
return Message::serialize();
|
|
}
|
|
|
|
std::unique_ptr<TextMessage> TextMessage::deserialize(const std::vector<uint8_t>& payload) {
|
|
size_t offset = 0;
|
|
|
|
// Read sender
|
|
if (offset + sizeof(uint16_t) > payload.size()) {
|
|
throw std::runtime_error("Invalid TextMessage: truncated sender length");
|
|
}
|
|
uint16_t sender_len;
|
|
std::memcpy(&sender_len, payload.data() + offset, sizeof(sender_len));
|
|
offset += sizeof(sender_len);
|
|
|
|
if (offset + sender_len > payload.size()) {
|
|
throw std::runtime_error("Invalid TextMessage: truncated sender");
|
|
}
|
|
std::string sender(payload.begin() + offset, payload.begin() + offset + sender_len);
|
|
offset += sender_len;
|
|
|
|
// Read content
|
|
if (offset + sizeof(uint16_t) > payload.size()) {
|
|
throw std::runtime_error("Invalid TextMessage: truncated content length");
|
|
}
|
|
uint16_t content_len;
|
|
std::memcpy(&content_len, payload.data() + offset, sizeof(content_len));
|
|
offset += sizeof(content_len);
|
|
|
|
if (offset + content_len > payload.size()) {
|
|
throw std::runtime_error("Invalid TextMessage: truncated content");
|
|
}
|
|
std::string content(payload.begin() + offset, payload.begin() + offset + content_len);
|
|
|
|
return std::make_unique<TextMessage>(sender, content);
|
|
}
|
|
|
|
} // namespace scar
|