Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
Himpoke authored Nov 3, 2024
2 parents dcbab69 + c6eb3a3 commit f01c59e
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 41 deletions.
27 changes: 19 additions & 8 deletions src/main/java/org/red5/server/net/rtmp/RTMPHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ public class RTMPHandler extends BaseRTMPHandler {
private static final String HIGH_RESOURCE_USAGE = "current system resources not enough";
private static final String INVALID_STREAM_NAME = "stream name is invalid. Don't use special characters.";


/**
* Status object service.
*/
Expand Down Expand Up @@ -113,6 +112,10 @@ public void setStatusObjectService(StatusObjectService statusObjectService) {
this.statusObjectService = statusObjectService;
}

public StatusObjectService getStatusObjectService() {
return statusObjectService;
}

public boolean isUnvalidatedConnectionAllowed() {
return unvalidatedConnectionAllowed;
}
Expand Down Expand Up @@ -206,7 +209,7 @@ protected void invokeCall(RTMPConnection conn, IServiceCall call) {
* Server-side service object
* @return true if the call was performed, otherwise false
*/
private boolean invokeCall(RTMPConnection conn, IServiceCall call, Object service) {
public boolean invokeCall(RTMPConnection conn, IServiceCall call, Object service) {
final IScope scope = conn.getScope();
final IContext context = scope.getContext();
if (log.isTraceEnabled()) {
Expand Down Expand Up @@ -283,15 +286,23 @@ protected void onCommand(RTMPConnection conn, Channel channel, Header source, IC
call.getArguments()[0] = streamId;
}

if(streamId.contains("?") && streamId.contains("=")) {
//this means query parameters (token, hash etc.) are added to URL, so split it
if (streamId.contains("?")) {
streamId = streamId.split("\\?")[0];
}

if(!StreamIdValidator.isStreamIdValid(streamId))
{
boolean isValidStreamId = false;

//check for both / and %2F
String[] pathSegments = streamId.split("/|%2F");


if (pathSegments.length > 0 && !pathSegments[0].isEmpty()) {
isValidStreamId = StreamIdValidator.isStreamIdValid(pathSegments[0]);
}

if (!isValidStreamId) {
Status status = getStatus(NS_FAILED).asStatus();
status.setDescription(INVALID_STREAM_NAME+" setream name:"+streamId);
status.setDescription(INVALID_STREAM_NAME + " stream name: " + streamId);
channel.sendStatus(status);
return;
}
Expand Down Expand Up @@ -567,7 +578,7 @@ public boolean isAllowedIfRtmpPlayback(RTMPConnection conn, Channel channel, Str
}

public StatusObject getStatus(String code) {
return statusObjectService.getStatusObject(code);
return getStatusObjectService().getStatusObject(code);
}

/** {@inheritDoc} */
Expand Down
61 changes: 43 additions & 18 deletions src/main/java/org/red5/server/stream/StreamService.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Map;
import java.util.Set;

import io.antmedia.websocket.WebSocketConstants;
import org.apache.commons.lang3.StringUtils;
import org.red5.io.utils.ObjectMap;
import org.red5.server.BaseConnection;
Expand Down Expand Up @@ -59,6 +60,8 @@
*/
public class StreamService implements IStreamService {

private static String[] urlKeys = {WebSocketConstants.STREAM_NAME, WebSocketConstants.TOKEN, WebSocketConstants.SUBSCRIBER_ID, WebSocketConstants.SUBSCRIBER_CODE};

private static Logger log = LoggerFactory.getLogger(StreamService.class);

/**
Expand Down Expand Up @@ -626,36 +629,58 @@ public void publish(Boolean dontStop) {
}
}


private Map<String, String> parseQueryParameters(String name) {
Map<String, String> params = new HashMap<>();
String tmp = name;
// check if we start with '?' or not
if (name.charAt(0) != '?') {
tmp = name.split("\\?")[1];
} else if (name.charAt(0) == '?') {
tmp = name.substring(1);
}
// now break up into key/value blocks
String[] kvs = tmp.split("&");
// take each key/value block and break into its key value parts
for (String kv : kvs) {
String[] split = kv.split("=");
params.put(split[0], split[1]);
}
return params;
}

// Handle the rtmp url format (e.g. /testStream/example_token/example_subscriberId/example_subscriberCode)
public Map<String, String> parsePathSegments(String name) {
Map<String, String> params = new HashMap<>();

// Split the name by both '/' and '%2F'
String[] pathSegments = name.split("/|%2F");

for (int i = 0; i < pathSegments.length && i < urlKeys.length; i++) {
params.put(urlKeys[i], pathSegments[i]);
}

return params;
}


/**
* {@inheritDoc}
* We have added "synchronized" because this method can be called exactly with the same names at the same time
* It creates an extra zombi scope that is not deleted anytime.
* It creates an extra zombi scope that is not deleted anytime.
* By synching this method, we prevent this problem.
*/
public synchronized void publish(String name, String mode) {

Map<String, String> params = null;
if (name != null && name.contains("?")) {
// read and utilize the query string values
params = new HashMap<String, String>();
String tmp = name;
// check if we start with '?' or not
if (name.charAt(0) != '?') {
tmp = name.split("\\?")[1];
} else if (name.charAt(0) == '?') {
tmp = name.substring(1);
}
// now break up into key/value blocks
String[] kvs = tmp.split("&");
// take each key/value block and break into its key value parts
for (String kv : kvs) {
String[] split = kv.split("=");
params.put(split[0], split[1]);
}
// grab the streams name
params = parseQueryParameters(name);
name = name.substring(0, name.indexOf("?"));
} else if (name != null && name.matches(".*[/%2F].*")) { // match / or %2F
params = parsePathSegments(name);
name = params.getOrDefault("streamName", name);
}

log.debug("publish called with name {} and mode {}", name, mode);
IConnection conn = Red5.getConnectionLocal();
if (conn instanceof IStreamCapableConnection) {
Expand Down
98 changes: 86 additions & 12 deletions src/test/java/org/red5/server/net/rtmp/ServerRTMPHandshakeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.*;
import static org.red5.server.net.rtmp.status.StatusCodes.NS_FAILED;

import java.util.Arrays;
import java.util.Collection;
Expand All @@ -13,19 +15,33 @@
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import io.antmedia.StreamIdValidator;
import io.antmedia.statistic.IStatsCollector;
import org.apache.commons.codec.binary.Hex;
import org.apache.mina.core.buffer.IoBuffer;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.red5.io.object.StreamAction;
import org.red5.io.utils.IOUtils;
import org.red5.server.Context;
import org.red5.server.api.IContext;
import org.red5.server.api.scope.IScope;
import org.red5.server.api.service.IServiceCall;
import org.red5.server.api.service.IServiceInvoker;
import org.red5.server.api.stream.IClientStream;
import org.red5.server.api.stream.IStreamService;
import org.red5.server.net.ICommand;
import org.red5.server.net.rtmp.message.Header;
import org.red5.server.net.rtmp.status.Status;
import org.red5.server.net.rtmp.status.StatusObject;
import org.red5.server.net.rtmp.status.StatusObjectService;
import org.red5.server.scope.Scope;
import org.red5.server.stream.StreamService;
import org.red5.server.util.ScopeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -262,18 +278,18 @@ public Boolean call() throws Exception {
public void testRTMPPlaybackAndAllowed() {
RTMPHandler rtmpHandler = new RTMPHandler();

RTMPConnection conn = Mockito.mock(RTMPConnection.class);
Channel channel = Mockito.mock(Channel.class);
RTMPConnection conn = mock(RTMPConnection.class);
Channel channel = mock(Channel.class);

assertTrue(rtmpHandler.isAllowedIfRtmpPlayback(conn, channel, StreamAction.CONNECT));
AppSettings appSettings = new AppSettings();
IScope scope = Mockito.mock(IScope.class);
Mockito.when(conn.getScope()).thenReturn(scope);
IScope scope = mock(IScope.class);
when(conn.getScope()).thenReturn(scope);

IContext context = Mockito.mock(IContext.class);
Mockito.when(scope.getContext()).thenReturn(context);
IContext context = mock(IContext.class);
when(scope.getContext()).thenReturn(context);

Mockito.when(context.getBean(Mockito.anyString())).thenReturn(appSettings);
when(context.getBean(Mockito.anyString())).thenReturn(appSettings);
StatusObjectService statusObjectService = new StatusObjectService();
statusObjectService.loadStatusObjects();
rtmpHandler.setStatusObjectService(statusObjectService);
Expand All @@ -294,19 +310,77 @@ public void testLogStreamNames() {
RTMPHandler rtmpHandler = new RTMPHandler();
rtmpHandler.logStreamNames(null);

RTMPConnection conn = Mockito.mock(RTMPConnection.class);
Collection<IClientStream> streams = Arrays.asList(Mockito.mock(IClientStream.class));
Mockito.when(conn.getStreams()).thenReturn(streams);
RTMPConnection conn = mock(RTMPConnection.class);
Collection<IClientStream> streams = Arrays.asList(mock(IClientStream.class));
when(conn.getStreams()).thenReturn(streams);
rtmpHandler.logStreamNames(conn);

RTMPMinaConnection rtmpConnection = new RTMPMinaConnection();
rtmpConnection.logWarning();

rtmpConnection.logStream(Mockito.mock(IClientStream.class));
rtmpConnection.logStream(mock(IClientStream.class));
}
catch (Exception e) {
fail(e.getMessage());
}
}

}
@Test
public void testParseSegmentParams(){
RTMPHandler rtmpHandler = spy(RTMPHandler.class);

RTMPConnection conn = mock(RTMPConnection.class);
StatusObjectService statusObjectService = mock(StatusObjectService.class);
StatusObject statusObject = mock(StatusObject.class);
Status status = mock(Status.class);
Channel channel = mock(Channel.class);
Header source = mock(Header.class);
ICommand command = mock(ICommand.class);
IServiceCall call = mock(IServiceCall.class);
Context context = mock(Context.class);
Scope scope = mock(Scope.class);
IStatsCollector resourceMonitor = mock(IStatsCollector.class);
IServiceInvoker serviceInvoker = mock(IServiceInvoker.class);
IStreamService streamService = mock(IStreamService.class);

when(command.getTransactionId()).thenReturn(1);
when(command.getCall()).thenReturn(call);
when(conn.isConnected()).thenReturn(true);
when(conn.getScope()).thenReturn(scope);
when(scope.getContext()).thenReturn(context);
when(context.hasBean(IStatsCollector.BEAN_NAME)).thenReturn(true);
when(context.getBean(IStatsCollector.BEAN_NAME)).thenReturn(resourceMonitor);
when(resourceMonitor.enoughResource()).thenReturn(true);

when(context.getServiceInvoker()).thenReturn(serviceInvoker);

when(serviceInvoker.invoke(call, streamService)).thenReturn(true);
when(rtmpHandler.getStatusObjectService()).thenReturn(statusObjectService);
when(statusObjectService.getStatusObject(NS_FAILED)).thenReturn(statusObject);
when(statusObject.asStatus()).thenReturn(status);

when(call.getServiceMethodName()).thenReturn("publish");

when(call.getArguments()).thenReturn(new Object[]{"testStream/token/subscriberId/subscriberCode"});

try (MockedStatic<ScopeUtils> mockedScopeUtils = mockStatic(ScopeUtils.class)) {
mockedScopeUtils.when(() -> ScopeUtils.getScopeService(any(), eq(IStreamService.class), eq(StreamService.class)))
.thenReturn(streamService);

rtmpHandler.onCommand(conn, channel, source, command);

when(call.getArguments()).thenReturn(new Object[]{"/testStream/token/subscriberId/subscriberCode"});
rtmpHandler.onCommand(conn, channel, source, command);

//invalid streamId
when(call.getArguments()).thenReturn(new Object[]{"@invalidStreamId!/token/subscriberId/subscriberCode"});
rtmpHandler.onCommand(conn, channel, source, command);

verify(rtmpHandler,times(1)).getStatus(NS_FAILED);
verify(channel,times(1)).sendStatus(status);
verify(rtmpHandler, times(2)).invokeCall(conn, call, streamService);

}

}
}
Loading

0 comments on commit f01c59e

Please sign in to comment.