Working with an Akka stream that has a one-time header



I have an application that accepts TCP socket connections; each connection will send data in the following form:

n{json}bbbbbbbbbb...

where n is the byte length of the json that follows, and the json will be something like {'splitEvery': 5}, which tells me how to break up and process the potentially infinite string of bytes that follows it.
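
A concrete message in this format (this is the same input the test case further down uses):

17{"splitEvery": 5}aaaaabbbbbcccccddd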

I want to process this stream with Akka in Scala. I think streams are the right tool for the job, but I am having a hard time finding any examples of a stream with distinct processing stages. Most streams seem to do the same thing over and over, like the prefixAndTail example here. That is close to how I want to handle the n{json} part of the stream, but the difference is that I only need to do it once per connection and then move on to a different "stage" of processing.
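
For reference, a minimal sketch of what prefixAndTail(n) gives you: a single pair of the first n elements plus a Source of the remainder, i.e. the "handle once, then switch stages" shape in question (toy one-string-per-element input, not raw bytes; assumes the usual imports and an implicit materializer, as in the snippets below):

Source(List("17", """{"splitEvery": 5}""", "aaaaa", "bbbbb"))
  .prefixAndTail(2)
  .flatMapConcat { case (header, tail) =>
    println(s"header: $header") // one-time header handling
    tail                        // second "stage": process the data elements
  }
  .runForeach(println)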

Can anyone show me an example of an Akka stream with distinct stages like this?

Below is a GraphStage that processes a stream of ByteStrings. It:

  • extracts the chunk size from the header
  • emits ByteStrings of the specified chunk size
import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import akka.util.ByteString
class PreProcessor extends GraphStage[FlowShape[ByteString, ByteString]] {
  val in: Inlet[ByteString] = Inlet("ParseHeader.in")
  val out: Outlet[ByteString] = Outlet("ParseHeader.out")
  override val shape = FlowShape.of(in, out)
  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) {
      var buffer = ByteString.empty
      var chunkSize: Option[Int] = None
      private var upstreamFinished = false
      private val headerPattern = """^\d+\{"splitEvery": (\d+)\}""".r
      /**
        * @param data The data to parse.
        * @return The chunk size and header size if the header
        * could be parsed.
        */
      def parseHeader(data: ByteString): Option[(Int, Int)] =
        headerPattern.
          findFirstMatchIn(data.decodeString("UTF-8")).
          map { mtch => (mtch.group(1).toInt, mtch.end) }
      setHandler(out, new OutHandler {
        override def onPull(): Unit = {
          if (isClosed(in)) emit()
          else pull(in)
        }
      })
      setHandler(in, new InHandler {
        override def onPush(): Unit = {
          val elem = grab(in)
          buffer ++= elem
          if (chunkSize.isEmpty) {
            parseHeader(buffer) foreach { case (chunk, headerSize) =>
              chunkSize = Some(chunk)
              buffer = buffer.drop(headerSize)
            }
          }
          emit()
        }
        override def onUpstreamFinish(): Unit = {
          upstreamFinished = true
          if (chunkSize.isEmpty || buffer.isEmpty) completeStage()
          else {
            if (isAvailable(out)) emit()
          }
        }
      })
      private def continue(): Unit =
        if (isClosed(in)) completeStage()
        else pull(in)
      // Pushes the next chunk if enough data is buffered; otherwise asks
      // upstream for more (or completes once everything has been emitted).
      private def emit(): Unit = {
        chunkSize match {
          case None => continue() // header not parsed yet; need more data
          case Some(size) =>
            if (upstreamFinished && buffer.isEmpty ||
               !upstreamFinished && buffer.size < size) {
              continue() // not enough buffered for a full chunk yet
            } else {
              // a short final chunk is allowed once upstream has finished
              val (chunk, nextBuffer) = buffer.splitAt(size)
              buffer = nextBuffer
              push(out, chunk)
            }
        }
      }
    }
}

And a test case illustrating its usage:

import akka.actor.ActorSystem
import akka.stream._
import akka.stream.scaladsl.Source
import akka.util.ByteString
import org.scalatest._
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.util.Random
class PreProcessorSpec extends FlatSpec {
  implicit val system = ActorSystem("Test")
  implicit val materializer = ActorMaterializer()
  val random = new Random
  "" should "" in {
    def splitRandom(s: String, n: Int): List[String] = s match {
      case "" => Nil
      case s =>
        val (head, tail) = s splitAt random.nextInt(n)
        head :: splitRandom(tail, n)
    }
    val input = """17{"splitEvery": 5}aaaaabbbbbcccccddd"""
    val strings = splitRandom(input, 7)
    println(strings.map(s => s"[$s]").mkString(" ") + "\n")
    val future = Source.fromIterator(() => strings.iterator).
      map(ByteString(_)).
      via(new PreProcessor()).
      map(_.decodeString("UTF-8")).
      runForeach(println)
    Await.result(future, 5 seconds)
  }
}

Sample output:

[17{"] [splitE] [very"] [] [: 5}] [aaaaa] [bbb] [bbcccc] [] [cddd]
aaaaa
bbbbb
ccccc
ddd

Since the chunk size depends on the content of the stream, but all processing stages must be set up before any stream data is processed, you cannot easily use a convenience method like grouped(chunkSize). I would suggest stripping the metadata off the head of the stream (by some means other than Akka streams) and feeding the rest of the stream into a source chunked with grouped(chunkSize), as sketched below.
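
A minimal sketch of that suggestion, assuming the data is available as a plain Iterator[Char] before the stream is materialized, and reusing the `input` string and the dummy getChunkSize from the snippet below. The header is consumed from the iterator up front, so grouped can be given a concrete chunk size (reusing a partially consumed external iterator is fine for a single materialization):

def readHeader(it: Iterator[Char]): Int = {
  val len = new StringBuilder
  var c = it.next()
  while (c.isDigit) { len.append(c); c = it.next() } // c is now the first JSON char
  val json = new StringBuilder().append(c)
  (1 until len.toString.toInt).foreach(_ => json.append(it.next()))
  getChunkSize(json.toString) // dummy parser from the snippet below
}

val it = input.iterator
val chunkSize = readHeader(it) // the header is now consumed from `it`
val future = Source.fromIterator(() => it).
  grouped(chunkSize).
  map(_.mkString).
  runForeach(println)
// aaaaa
// bbbbb
// ccccc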

Alternatively, you can fold/scan the stream with a state machine, but that is considerably more cumbersome:

implicit val system = ActorSystem("Test")
implicit val materializer = ActorMaterializer()
val input = """17{"splitEvery": 5}aaaaabbbbbccccc"""
def getChunkSize(json: String) = 5 // dummy implementation
sealed trait State
case class GetLength(number: String) extends State
case class GetJson(n: Int, json: String) extends State
case class ProcessData(chunkSize: Int, s: String) extends State
type Out = (State, Option[String])
val future = Source.fromIterator(() => input.iterator).
  scan[Out]((GetLength(""), None)) {
    case ((GetLength(s), _), e) if e.isDigit => (GetLength(s + e), None)
    case ((GetLength(s), _), e) => (GetJson(s.toInt - 1, e.toString), None)
    case ((GetJson(0, json), _), e) => (ProcessData(getChunkSize(json), e.toString), None)
    case ((GetJson(n, json), _), e) => (GetJson(n - 1, json + e), None)
    case ((ProcessData(chunkSize, s), _), e) if s.length == chunkSize - 1 => (ProcessData(chunkSize, ""), Some(s + e))
    case ((ProcessData(chunkSize, s), _), e) => (ProcessData(chunkSize, s + e), None)
  }.
  collect { case (_, Some(s)) => s }.
  runForeach(println)
println(Await.result(future, 1 second))
// aaaaa
// bbbbb
// ccccc

For the record, here is an approach that does not work, because takeWhile consumes the iterator's next element (the one for which _.isDigit fails); it would also still need a subsequent JSON parsing stage:

val it = input.iterator
def nextSource = Source.fromIterator(() => it)
implicit class Stringify[+Out, +Mat](val source: Source[Out, Mat]) {
  def stringify = source.runFold("")(_ + _)
}
val future2 = nextSource.
  takeWhile(_.isDigit).
  stringify.
  map(_.toInt).
  map { l =>
    nextSource.
      take(l).
      stringify.
      map(getChunkSize).
      map { chunkSize =>
        nextSource.
          grouped(chunkSize).
          map(_.mkString).
          runForeach(println)
      }
  }
println(Await.result(future2, 1 second))
// aaaab
// bbbbc
// cccc

I needed to handle a 2-byte length header followed by data. The GraphStage logic below for accumulating and framing data based on a length prefix may help. I used various online Akka docs and took the idea of implementing it in Java from the solutions provided above.
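
As an aside, recent Akka releases ship a built-in stage for exactly this kind of length-prefixed framing: Framing.lengthField. A minimal sketch of the Scala API (note that, as I understand it, the emitted frames include the length field, hence the drop(2), and a truncated trailing frame fails the stream with a FramingException rather than being dropped):

import java.nio.ByteOrder
import akka.NotUsed
import akka.stream.scaladsl.{Flow, Framing}
import akka.util.ByteString

val deframe: Flow[ByteString, ByteString, NotUsed] =
  Framing.lengthField(
    fieldLength = 2,
    fieldOffset = 0,
    maximumFrameLength = 65535,
    byteOrder = ByteOrder.BIG_ENDIAN
  ).map(_.drop(2)) // strip the 2-byte header, keep only the payload

If that behavior fits your needs, it can replace the hand-rolled stage below.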

Java code:
package com.example;
import akka.Done;
import akka.NotUsed;
import akka.actor.typed.ActorSystem;
import akka.actor.typed.scaladsl.Behaviors;
import akka.stream.*;
import akka.stream.javadsl.Flow;
import akka.stream.javadsl.Sink;
import akka.stream.javadsl.Source;
import akka.stream.stage.GraphStage;
import akka.stream.stage.GraphStageLogic;
import akka.stream.stage.InHandler;
import akka.stream.stage.OutHandler;
import akka.util.ByteString;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletionStage;
public class TwoByteLengthFramingFlow extends GraphStage<FlowShape<ByteString, ByteString>> {
    private final Inlet<ByteString> inlet = Inlet.create("TwoByteLengthFraming.in");
    private final Outlet<ByteString> outlet = Outlet.create("TwoByteLengthFraming.out");
    FlowShape<ByteString, ByteString> shape = FlowShape.of(inlet, outlet);
    public static void main(String[] args) {
        ActorSystem system = ActorSystem.create(Behaviors.empty(), "Blah");
        byte[] b0 = {0, 4, 'A', 'B', 'C', 'D', 0}; // the trailing 0 simulates an incomplete length header at the end of packet 1
        byte[] b1 = {4, 'E', 'F', 'G', 'H', 'A', 9}; // the trailing 'A', 9 simulate the stream ending with an incomplete message left in the buffer
        ByteString x0 = ByteString.fromArray(b0);
        ByteString x1 = ByteString.fromArray(b1);
        List<ByteString> l = new ArrayList<>();
        // simulate messages arriving as TCP packets, from which complete messages must be reassembled
        l.add(x0);
        l.add(x1);
        Graph<FlowShape<ByteString, ByteString>, NotUsed> flowgraph = new TwoByteLengthFramingFlow();
        Flow<ByteString, ByteString, NotUsed> flow = Flow.fromGraph(flowgraph);
        Sink<ByteString, CompletionStage<Done>> printSink = Sink.foreach(msg -> System.out.println(msg.utf8String()));
        Source.from(l)
              .via(flow)
              .to(printSink)
              .run(system);

        // Just to see what happens when the sink cancelled terminates the flow
//        Source.from(l)
//              .via(flow)
//              .to(Sink.cancelled())
//              .run(system);
    }
    Flow<ByteString, ByteString, NotUsed> getFlow() {
        Graph<FlowShape<ByteString, ByteString>, NotUsed> flowgraph = new TwoByteLengthFramingFlow();
        Flow<ByteString, ByteString, NotUsed> flow = Flow.fromGraph(flowgraph);
        return flow;
    }
    @Override
    public FlowShape<ByteString, ByteString> shape() {
        return shape;
    }
    @Override
    public GraphStageLogic createLogic(Attributes inheritedAttributes) throws Exception {
        return new GraphStageLogicExtension(shape);
    }
    private final class GraphStageLogicExtension extends GraphStageLogic {
        private final List<ByteString> messages = new ArrayList<>();
        protected ByteString buffer = ByteString.emptyByteString();
        private GraphStageLogicExtension(Shape shape) {
            super(shape);
            setHandler(inlet, new InHandler() {
                @Override
                public void onPush() throws Exception {
                    System.out.println("onPush()");
                    // upstream pushed data, so our onPush was called;
                    // append the incoming bytes to the buffer, which may
                    // already hold a partial message from earlier packets
                    buffer = buffer.concat(grab(inlet));
                    // Recursively extract as many complete [len][message] frames
                    // as possible. Whatever cannot be extracted yet (a partial
                    // length header, or a length whose data has not fully
                    // arrived) stays in the buffer until more data arrives.
                    extractMessages();
                    // emit extracted messages
                    emitChunk();
                }
                @Override
                public void onUpstreamFinish() throws Exception {
                    System.out.println("onUpstreamFinish()");
                    // upstream signalled its done
                    if (buffer.size() == 0 && messages.size() == 0) {
                        // no incomplete message in buffer
                        completeStage();
                    }
                    else {
                        // There are elements left in buffer, so
                        // we keep accepting downstream pulls and push from buffer until emptied.
                        //
                        // It might be though, that the upstream finished while it was pulled, in which
                        // case we will not get an onPull from the downstream, because we already had one.
                        // In that case we need to emit from the buffer.
                        if (isAvailable(outlet))
                            emitChunk();
                    }

                }
                private void emitChunk() {
                    System.out.println("emitChunk()");
                    // if we have no extracted messages
                    if (messages.size() <= 0) {
                        // if the upstream closed the inlet, we are done
                        if (isClosed(inlet)) {
                            completeStage();
                        }
                        // we can pull to get more data
                        else {
                            System.out.println("pull()");
                            pull(inlet);
                        }
                    }
                    else {
                        // we have messages so send one and remove it from the list.
                        System.out.println("emit()");
                        emit(outlet, messages.remove(0));
                    }
                }
            });
            setHandler(outlet, new OutHandler() {
                @Override
                public void onPull() throws Exception {
                    System.out.println("onPull()");
                    // downstream pulled, so our onPull was called
                    if (messages.size() > 0) {
                        // we have a complete message ready, push it
                        System.out.println("push()");
                        push(outlet, messages.remove(0));
                    }
                    else {
                        // no message ready: pull more data from upstream; it
                        // will push in response and our onPush will be called
                        System.out.println("pull()");
                        pull(inlet);
                    }
                }
                @Override
                public void onDownstreamFinish() throws Exception {
                    System.out.println("Downstream  Finished");
                    OutHandler.super.onDownstreamFinish();
                }
            });
        }
        protected void extractMessages() {
            Tuple2<ByteString, ByteString> lengthDataTuple = buffer.splitAt(2);
            int messageLength = getLength(lengthDataTuple._1);
            if ((messageLength < 0) || (lengthDataTuple._2.take(messageLength)
                                                          .size() != messageLength)) {
                return;
            }
            if (messageLength == 0) {
                //maybe its a 0 byte ping message. Let the next stage handle empty bytestring.
                messages.add(ByteString.emptyByteString());
            }
            else {
                messages.add(lengthDataTuple._2.take(messageLength));
            }
            //Update buffer by removing messages that could be extracted
            buffer = buffer.drop(2 + messageLength);
            // recurse, till we can extract whatever is possible
            extractMessages();
        }
        private int getLength(ByteString header) {
/*
The 2-byte length is big-endian: the high byte is in b[0] and the low byte in b[1].
  4   = 0x0004 : b[0] = 0x00, b[1] = 0x04
  255 = 0x00FF : b[0] = 0x00, b[1] = 0xFF
  256 = 0x0100 : b[0] = 0x01, b[1] = 0x00
 */
            byte[] b = header.toArray();
            if (b.length == 2) {
                return (((b[0]) & 0xFF) << 8) | ((b[1]) & 0xFF);
            }
            return -1;
        }
    }
}
