付延余
2022-12-16 f0f8ee8c4a945adbc742d9bab69382b28ad311fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
package com.wgcloud.common;
 
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.jcraft.jsch.ChannelShell;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import com.wgcloud.util.staticvar.StaticKeys;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
 
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
 
/**
 * @version v3.3
 * @ClassName:NettytHandler.java
 * @author: http://www.wgstart.com
 * @date: 2021年4月22日
 * @Description: netty处理handler
 * @Copyright: 2017-2021 wgcloud. All rights reserved.
 */
public class NettytHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
 
    private static final Logger logger = LoggerFactory.getLogger(NettytHandler.class);
 
    //操作类型
    public static final String HANDLE_OPERATE = "handle";
    //操作指令
    public static final String HANDLE_VALUE = "value";
    //连接超时30s
    public static final Integer CONNECTION_OUT = 30000;
    //回车
    public static final String ENTER_VAL = "\r";
    //换行
    public static final String LINE_NEXT_VAL = "\n";
    //TAB制表符
    public static final String TAB_VAL = "\t";
 
    //所有正在连接的channel都会存在这里面,所以也可以间接代表在线的客户端
    public static ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    //客户端执行的命令
    public static Map<String, String> MAP_CMD = Collections.synchronizedMap(new HashMap<String, String>());
 
    //客户端保存的ssh连接session
    public static Map<String, ChannelShell> MAP_SSH_SESSION = Collections.synchronizedMap(new HashMap<String, ChannelShell>());
 
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
        String channelId = ctx.channel().id().toString();
        String msgJSonStr = msg.text();
        if (StringUtils.isEmpty(msgJSonStr)) {
            return;
        }
        JSONObject msgJson = JSONUtil.parseObj(msgJSonStr);
        if ("connect".equals(msgJson.getStr(HANDLE_OPERATE))) {
            //执行连接操作
            try {
                getSSHChannel(channelId, msgJson.getStr("ip"), msgJson.getStr("user"), msgJson.getStr("pwd"),
                        Integer.valueOf(msgJson.getStr("port")), msgJson.getStr("priKeyBasePath"));
                executeCommand(ctx, ENTER_VAL);
            } catch (JSchException e) {
                ctx.writeAndFlush(new TextWebSocketFrame(LINE_NEXT_VAL + e.toString()));
                logger.error("ssh终端连接错误:", e);
            }
        } else if ("cmd".equals(msgJson.getStr(HANDLE_OPERATE))) {
            //发送指令执行
            String cmdStr = msgJson.getStr(HANDLE_VALUE);
            //判断指令是否包含TAB,若是则不做任何操作,直接返回
            if (cmdStr.endsWith(TAB_VAL)) {
                return;
            }
            //普通指令字符使用MAP_CMD存贮该终端的指令
            if (StringUtils.isEmpty(MAP_CMD.get(channelId))) {
                MAP_CMD.put(channelId, cmdStr);
            } else {
                MAP_CMD.put(channelId, MAP_CMD.get(channelId) + cmdStr);
            }
            //判断指令是否包含回车,若是则发送执行,执行完成清空MAP_CMD已存贮该终端的指令
            if (cmdStr.contains(ENTER_VAL)) {
                if (!StringUtils.isEmpty(MAP_CMD.get(channelId))) {
                    executeCommand(ctx, MAP_CMD.get(channelId));
                    MAP_CMD.remove(channelId);
                }
            }
        }
    }
 
    //客户端建立连接
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        channelGroup.add(ctx.channel());
        logger.info(ctx.channel().remoteAddress() + "ssh终端上线了!");
    }
 
    //关闭连接
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        channelGroup.remove(ctx.channel());
        logger.info(ctx.channel().remoteAddress() + "ssh终端断开连接");
        ChannelShell channelShell = MAP_SSH_SESSION.get(ctx.channel().id().toString());
        if (channelShell != null) {
            Session session = channelShell.getSession();
            if (channelShell != null) {
                channelShell.disconnect();
            }
            if (session != null) {
                session.disconnect();
            }
            MAP_SSH_SESSION.remove(ctx.channel().id().toString());
            MAP_CMD.remove(ctx.channel().id().toString());
        }
    }
 
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        ctx.channel().close();
    }
 
    /**
     * ssh2连接服务器,并将channel存到map里
     *
     * @param channelId      终端ID
     * @param host           主机IP
     * @param user           用户名
     * @param password       密码
     * @param port           ssh端口
     * @param priKeyBasePath 私钥文件
     * @return
     */
    private static ChannelShell getSSHChannel(String channelId, String host, String user, String password, Integer port, String priKeyBasePath) throws JSchException {
        ChannelShell channelShell = MAP_SSH_SESSION.get(channelId);
        if (channelShell != null) {
            return channelShell;
        }
        JSch jsch = new JSch();
        Session session = jsch.getSession(user, host, port);
        if (!StringUtils.isEmpty(password)) {
            //密码登录
            session.setPassword(password);
            //忽略第一次连接时候 hostkey 检查
            session.setConfig("StrictHostKeyChecking", "no");
        } else {
            //添加私钥登录
            logger.debug("priKeyBasePath-----------" + StaticKeys.JAR_PATH + "/" + priKeyBasePath);
            jsch.addIdentity(StaticKeys.JAR_PATH + "/" + priKeyBasePath);
            Properties config = new Properties();
            config.put("StrictHostKeyChecking", "no");
            session.setConfig(config);
        }
        session.connect(CONNECTION_OUT);
        // 设置timeout时间
        session.setTimeout(600000);
        //开启shell,shell 具有上下文交互,执行命令不会马上退出
        channelShell = (ChannelShell) session.openChannel("shell");
        channelShell.connect(CONNECTION_OUT);
        channelShell.setPtyType("dumb");
        channelShell.setPty(true);
        MAP_SSH_SESSION.put(channelId, channelShell);
        return channelShell;
    }
 
    /**
     * 执行shell脚本指令
     *
     * @param ctx
     * @param cmds
     */
    private static void executeCommand(ChannelHandlerContext ctx, String cmds) {
        ctx.writeAndFlush(new TextWebSocketFrame(LINE_NEXT_VAL));
        try {
            ChannelShell channelShell = MAP_SSH_SESSION.get(ctx.channel().id().toString());
            if (null == channelShell) {
                ctx.writeAndFlush(new TextWebSocketFrame("\n\r~$ "));
                return;
            }
            InputStream inputStream = channelShell.getInputStream();
            OutputStream outputStream = channelShell.getOutputStream();
            outputStream.write((cmds).getBytes());
            outputStream.flush();
            if (!StringUtils.isEmpty(cmds)) {
                logger.info(channelShell.getSession().getHost() + "," + channelShell.getSession().getUserName() + ",执行ssh指令:" + cmds);
            }
            String showMsg = "";
            Thread.sleep(400);
            byte[] tmp = new byte[4096];
            //读取shell命令执行结果循环次数,最大3,也就是每次读取执行结果等待1500ms,此期间若没有返回数据就会退出
            int beat = 0;
            while (true) {
                //循环读取shell执行结果  begin
                while (inputStream.available() > 0) {
                    int i = inputStream.read(tmp);
                    if (i < 0) {
                        break;
                    }
                    showMsg = new String(tmp, 0, i, "utf-8");
                    ctx.writeAndFlush(new TextWebSocketFrame(showMsg));
                    Thread.sleep(500);
                }
                //循环读取shell执行结果 end
 
                Thread.sleep(500);
 
                //判断shell通道是否关闭 begin
                if (channelShell.isClosed()) {
                    //查询是否还有可读取的字节,若有则继续循环读取
                    if (inputStream.available() > 0) {
                        continue;
                    }
                    ctx.writeAndFlush(new TextWebSocketFrame("exit-status: " + channelShell.getExitStatus() + ",会话通道已超时请重新连接"));
                    break;
                }
                //判断shell通道是否关闭 end
 
                beat++;
                //循环3次,读取结束
                if (beat > 2) {
                    break;
                }
            }
        } catch (Exception e) {
            logger.error("shh指令执行错误:", e);
        }
    }
 
    /**
     * 清空web ssh终端缓存map
     */
    public static void clearOldData() {
        MAP_CMD.clear();
        MAP_SSH_SESSION.clear();
    }
}