sevencat 发表于 2013-12-15 21:38:56

thrift的asio实现

thrift是目前能找到的轻量级,高性能,非常灵活的一个rpc。
官方的实现是用libevent2+多线程做为网络接口,网上也有人有了asio的实现。

写了个简单的
#pragma once
#include "Api.h"
#include <iostream>
#include <stdexcept>
#include <sstream>
#include "rpc_conn.h"

using namespace apache::thrift;
using namespace apache::thrift::protocol;
using namespace apache::thrift::transport;

class RpcJgServer : virtual public jgrpc::ApiIf
{
public:
    explicit RpcJgServer(const std::string& address, int port,std::size_t thread_pool_size);

    void run();

    void stop();

private:
    void start_accept();
    void handle_accept(RpcConnectionPtr conn,const asio::error_code& e);

    std::size_t thread_pool_size_;
    asio::io_service io_service_;

    asio::ip::tcp::acceptor acceptor_;

    boost::shared_ptr<TProcessor> processor_;
    boost::shared_ptr<TProtocolFactory> input_protocol_factory_;
    boost::shared_ptr<TProtocolFactory> output_protocol_factory_;
    boost::shared_ptr<jgrpc::ApiIf> evt_If;

public:
    std::set<RpcConnectionPtr> clients_;//已登录的客户端
    asio::detail::mutex mutex_;//是一把锁

public:
    void jgrpc::ApiIf::ping(void){};
};
#include "stdafx.h"
#include "rpc_jg.h"

RpcJgServer::RpcJgServer(const std::string& address, int port,std::size_t thread_pool_size)
               : thread_pool_size_(thread_pool_size),
               acceptor_(io_service_, asio::ip::tcp::endpoint(asio::ip::address_v4::from_string(address), port))
{
    boost::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
    input_protocol_factory_=protocolFactory;
    output_protocol_factory_=protocolFactory;
    evt_If.reset(this);
    processor_.reset(new jgrpc::ApiProcessor(evt_If));
    start_accept();
}

void RpcJgServer::start_accept()
{
    RpcConnectionPtr new_connection;
    new_connection.reset(new RpcConnection(io_service_, processor_,input_protocol_factory_, output_protocol_factory_));
    acceptor_.async_accept(new_connection->socket(),
      boost::bind(&RpcJgServer::handle_accept, this,
      new_connection,
      asio::placeholders::error));   
}
void RpcJgServer::run()
{
    std::vector<boost::shared_ptr<asio::thread> > threads;
    for (std::size_t i = 0; i < thread_pool_size_; ++i)
    {
      boost::shared_ptr<asio::thread> thread(new asio::thread(
            boost::bind(&asio::io_service::run, &io_service_)));
      threads.push_back(thread);
    }

    for (std::size_t i = 0; i < threads.size(); ++i) {
      threads->join();
    }
}

void RpcJgServer::stop()
{
    io_service_.stop();
}

void RpcJgServer::handle_accept(RpcConnectionPtr conn,const asio::error_code& e)
{
    if (!e)
    {
      //这里进行处理
      conn->start();
    }
    start_accept();   
}#pragma once

using namespace apache::thrift;
using namespace apache::thrift::protocol;
using namespace apache::thrift::transport;

class RpcConnection: public boost::enable_shared_from_this<RpcConnection>,private boost::noncopyable
{
public:
        explicit RpcConnection(asio::io_service& io_service,
                boost::shared_ptr<TProcessor> processor,
                boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
                boost::shared_ptr<TProtocolFactory> outputProtocolFactory);

        asio::ip::tcp::socket& socket()
        {
                return socket_;
        }
        void start()
        {
                asio::async_read(socket_, asio::buffer(length_),
                        strand_.wrap(
                        boost::bind(&RpcConnection::handle_read_length, shared_from_this(),
                        asio::placeholders::error)));
        }

        static RpcConnection *get_current_conn();

private:
        void handle_read_length(const asio::error_code& e);
        void handle_read_frame(int32_t sz, const asio::error_code& e);
        void handle_write(const asio::error_code& e);
        asio::io_service::strand strand_;
        asio::ip::tcp::socket socket_;
        boost::array<uint8_t, 4> length_;
        boost::shared_ptr<TProcessor> processor_;
        boost::shared_ptr<TProtocolFactory> inputProtocolFactory_;
        boost::shared_ptr<TProtocolFactory> outputProtocolFactory_;
        std::vector<uint8_t> frame_;
        asio::io_service& io_service_;
       
};

typedef boost::shared_ptr<RpcConnection> RpcConnectionPtr;#include "stdafx.h"
#include "rpc_conn.h"
#include <boost/thread/tss.hpp>

static boost::thread_specific_ptr<RpcConnection> tls_sessiondata;
static boost::shared_ptr<boost::detail::tss_cleanup_function> tls_cleanup;

RpcConnection *RpcConnection::get_current_conn()
{
        RpcConnection *cursession=tls_sessiondata.get();
        return cursession;
}

RpcConnection::RpcConnection(asio::io_service& io_service,
                                           boost::shared_ptr<TProcessor> processor,
                                           boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
                                           boost::shared_ptr<TProtocolFactory> outputProtocolFactory
                                           )
                                           : strand_(io_service),
                                           socket_(io_service),
                                           processor_(processor),
                                           inputProtocolFactory_(inputProtocolFactory),
                                           outputProtocolFactory_(outputProtocolFactory),
                                           io_service_(io_service)
{
}

void RpcConnection::handle_read_length(const asio::error_code& e)
{
        if (!e)
        {
                int32_t sz =
                        ((length_ & 0xff) << 24) |
                        ((length_ & 0xff) << 16) |
                        ((length_ & 0xff) <<8) |
                        ((length_ & 0xff));

                if (sz < 0)
                {
                        std::cerr << "Frame length is negative. Are you sure your client uses a TFramedTransport?" << std::endl;
                        delete this;
                }
                else if (sz == 0)
                {
                        std::cerr << "Frame length is zero, weird" << std::endl;
                        delete this;
                }
                else
                {
                        frame_.resize(sz);
                        asio::async_read(socket_,
                                asio::buffer(frame_, sz),
                                strand_.wrap(
                                boost::bind(&RpcConnection::handle_read_frame, shared_from_this(), sz,asio::placeholders::error)));
                }
        }
}

//放到tls里面
void RpcConnection::handle_read_frame(int32_t sz, const asio::error_code& e)
{
        if (!e)
        {
                boost::detail::set_tss_data(&tls_sessiondata,tls_cleanup,this,false);
                boost::shared_ptr<TMemoryBuffer> inputTransport(new TMemoryBuffer(&frame_, sz));
                boost::shared_ptr<TProtocol> inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);

                boost::shared_ptr<TMemoryBuffer> outputTransport(new TMemoryBuffer);
                boost::shared_ptr<TFramedTransport> outputFramedTransport(new TFramedTransport(outputTransport));
                boost::shared_ptr<TProtocol> outputProtocol = outputProtocolFactory_->getProtocol(outputFramedTransport);

                processor_->process(inputProtocol, outputProtocol,NULL);

                uint32_t outframeSize;
                uint8_t *outframe;

                outputTransport->getBuffer(&outframe, &outframeSize);

                asio::async_write(socket_,
                        asio::buffer(outframe, outframeSize),
                        strand_.wrap(boost::bind(&RpcConnection::handle_write, shared_from_this(), asio::placeholders::error)));

        }
}

void RpcConnection::handle_write(const asio::error_code& e)
{
        if (!e)
        {
                asio::async_read(socket_, asio::buffer(length_),
                        strand_.wrap(boost::bind(&RpcConnection::handle_read_length, shared_from_this(), asio::placeholders::error)));
        }
}

sevencat 发表于 2013-12-15 21:39:43

没有实现servercontext接口,因为在这个框架下你可以直接通过tls直接获取connection。
本代码没有经过哪怕简单的测试。

sevencat 发表于 2013-12-15 21:43:09

注:TFramedTransport这个东东其实就是在传输前加上一个长度。所以代码中读的时候先读一个整数,写的时候就直接调用的
boost::shared_ptr<TFramedTransport> outputFramedTransport(new TFramedTransport(outputTransport));
页: [1]
查看完整版本: thrift的asio实现